内容简介
本实验介绍了使用MindSpore框架实现K近邻算法(KNN)对红酒数据集进行分类的全过程。通过数据读取、预处理、模型构建与预测,展示了KNN算法在红酒数据集上的应用。实验中详细解释了KNN的原理、距离度量方式及其在分类问题中的应用,最后通过验证集评估模型性能,验证了KNN算法在该3分类任务上的有效性。
实验代码及注释
# 导入必要的库 import os import csv import numpy as np import matplotlib.pyplot as plt import mindspore as ms from mindspore import nn, ops # 设置MindSpore的运行环境 ms.set_context(device_target="CPU") # 读取数据集 with open('wine.data') as csv_file: data = list(csv.reader(csv_file, delimiter=',')) print(data[56:62]+data[130:133]) # 数据处理 # 将数据集的13个属性作为自变量 X,将3个类别作为因变量 Y X = np.array([[float(x) for x in s[1:]] for s in data[:178]], np.float32) Y = np.array([s[0] for s in data[:178]], np.int32) # 可视化样本分布 attrs = ['Alcohol', 'Malic acid', 'Ash', 'Alcalinity of ash', 'Magnesium', 'Total phenols', 'Flavanoids', 'Nonflavanoid phenols', 'Proanthocyanins', 'Color intensity', 'Hue', 'OD280/OD315 of diluted wines', 'Proline'] plt.figure(figsize=(10, 8)) for i in range(0, 4): plt.subplot(2, 2, i+1) a1, a2 = 2 * i, 2 * i + 1 plt.scatter(X[:59, a1], X[:59, a2], label='1') plt.scatter(X[59:130, a1], X[59:130, a2], label='2') plt.scatter(X[130:, a1], X[130:, a2], label='3') plt.xlabel(attrs[a1]) plt.ylabel(attrs[a2]) plt.legend() plt.show() # 划分训练集和测试集 train_idx = np.random.choice(178, 128, replace=False) test_idx = np.array(list(set(range(178)) - set(train_idx))) X_train, Y_train = X[train_idx], Y[train_idx] X_test, Y_test = X[test_idx], Y[test_idx] # 构建KNN模型 class KnnNet(nn.Cell): def __init__(self, k): super(KnnNet, self).__init__() self.k = k def construct(self, x, X_train): x_tile = ops.tile(x, (128, 1)) # 平铺输入x以匹配X_train中的样本数 square_diff = ops.square(x_tile - X_train) square_dist = ops.sum(square_diff, 1) dist = ops.sqrt(square_dist) values, indices = ops.topk(-dist, self.k) # -dist表示值越大,样本就越接近 return indices def knn(knn_net, x, X_train, Y_train): x, X_train = ms.Tensor(x), ms.Tensor(X_train) indices = knn_net(x, X_train) topk_cls = [0]*len(indices.asnumpy()) for idx in indices.asnumpy(): topk_cls[Y_train[idx]] += 1 cls = np.argmax(topk_cls) return cls # 模型预测 acc = 0 knn_net = KnnNet(5) for x, y in zip(X_test, Y_test): pred = knn(knn_net, x, X_train, Y_train) acc += (pred == y) print('label: %d, prediction: %s' % (y, pred)) print('Validation accuracy is %f' % (acc/len(Y_test)))
学习心得
通过本次实验,我深入了解了K近邻算法(KNN)及其在分类任务中的应用。KNN是一种基于实例的学习算法,利用训练样本的多数表决结果来对新样本进行分类。它具有简单、直观的优点,但在大规模数据集上计算复杂度较高,因此需要在应用时进行适当的优化和改进。
在实验过程中,首先进行了数据读取与处理。Wine数据集包含13个属性,每个属性都对分类结果有不同程度的影响。通过可视化展示了不同类别样本在某两个属性上的分布,帮助我们直观地理解数据的可分性。接下来,通过划分训练集和测试集,保证了模型的训练和验证能够在不同的数据上进行,从而提高模型的泛化能力。
模型构建部分,通过使用MindSpore框架,利用其提供的高效算子如tile、square、ReduceSum等,构建了KNN模型。在计算距离时,选择了欧氏距离,并使用TopK算子找出距离最近的k个邻居。对于分类决策,采用了多数表决的方式,即统计k个邻居中每个类别的数量,选择最多的类别作为预测结果。
在验证阶段,取k值为5进行模型预测,验证精度约为70%。虽然准确率不算特别高,但对于一个简单的3分类任务,KNN算法仍然展现出了其有效性。通过调整k值或者加入样本权重,可以进一步优化模型性能。
label: 2, prediction: 3
label: 3, prediction: 2
label: 1, prediction: 1
label: 3, prediction: 3
label: 1, prediction: 1
label: 1, prediction: 1
label: 3, prediction: 3
label: 3, prediction: 3
label: 1, prediction: 1
label: 1, prediction: 1
label: 1, prediction: 1
label: 3, prediction: 3
label: 1, prediction: 1
label: 1, prediction: 1
label: 3, prediction: 1
label: 1, prediction: 3
label: 1, prediction: 3
label: 3, prediction: 3
label: 1, prediction: 1
label: 3, prediction: 3
label: 3, prediction: 3
label: 3, prediction: 3
label: 1, prediction: 1
label: 3, prediction: 3
label: 3, prediction: 1
label: 1, prediction: 1
label: 1, prediction: 1
label: 2, prediction: 2
label: 2, prediction: 2
label: 2, prediction: 2
label: 2, prediction: 2
label: 2, prediction: 3
label: 2, prediction: 1
label: 2, prediction: 1
label: 2, prediction: 3
label: 2, prediction: 1
label: 2, prediction: 3
label: 2, prediction: 2
label: 2, prediction: 3
label: 2, prediction: 2
label: 2, prediction: 2
label: 2, prediction: 3
label: 2, prediction: 3
label: 2, prediction: 2
label: 2, prediction: 2
label: 2, prediction: 2
label: 2, prediction: 2
label: 2, prediction: 2
label: 2, prediction: 2
label: 2, prediction: 2
Validation accuracy is 0.700000
本次实验不仅让我掌握了KNN算法的实现过程,还了解了MindSpore框架在机器学习任务中的应用。通过实验操作,进一步巩固了机器学习理论知识,提升了编程实战能力。同时,也深刻认识到在处理实际问题时,数据预处理和特征工程的重要性。