2019-03-05深度学习——搭建一个简单的KnnClassi
2019-03-05 本文已影响0人
Hie_9e55
KNN思想
NN选出与目标图片距离(范数)最近的一张图片
KNN选出与目标图片距离(范数)最近的K张图片,并统计K张图中出现次数最多的类型,即为预测类型
代码实现
分为两个部分
第一部分是knn模型
第二部分是模型的使用
- KNN.py
# KNN.py
# 导入所需要的库
# 这里我们需要使用numpy库进行矩阵运算
# 使用collections中的Counter
import numpy as np
from collections import Counter
class KNearestNeighbor:
def __init__(self, k = 7):
self.k = k
# 训练模型,KNN只是简单的导入即可,因为K是一个超参数,X是数据,n*3072,Y是数据标签,n*1
def train(self, X, y):
self.Xtr = X
self.ytr = y
# 使用模型进行预测,X是test集的数据
def predict(self, X):
num_test = X.shape[0]# test数据个数
Ypred = np.zeros((num_test, len(self.k)))# 初始化预测结果
for i in range(num_test):# 每次迭代一张图片
distances = np.sum(np.abs(self.Xtr - X[i,:]), axis = 1)# 计算范数一
temp = np.argsort(distances)# numpy.argsort()函数返回的是数组值从小到大的索引值
index = []
for j in range(len(self.k)):
# 取出前k个下标
id = temp[0:self.k[j]]
# 取出k个下标所对应的类型
temp_y = np.array(self.ytr)[id]
# 取出k个下标中出现次数最多的label
index.append(Counter(temp_y).most_common(1)[0][0])
print(np.array(index), i)
Ypred[i] = np.array(index)# 记录第i张图片在进行knn之后的label
return Ypred
- runKNN.py
# runKNN.py
# 导入所需要的库
# pickle用于解压数据
# matplotlib用于绘图
# numpy用于矩阵运算
# KNN用于预测label
import pickle
from matplotlib import pyplot as plt
import numpy as np
from KNN import KNearestNeighbor
# 数据地址
filename1 = 'D:/Download/cifar-10-batches-py/data_batch_1'
filename2 = 'D:/Download/cifar-10-batches-py/data_batch_2'
filename3 = 'D:/Download/cifar-10-batches-py/data_batch_3'
filename4 = 'D:/Download/cifar-10-batches-py/data_batch_4'
filename5 = 'D:/Download/cifar-10-batches-py/data_batch_5'
filename_test = 'D:/Download/cifar-10-batches-py/test_batch'
# 定义导入数据的函数
def load_file(filename):
with open(filename, 'rb') as fo:
data = pickle.load(fo, encoding='latin1')
return data
# 使用pickle暴力导入数据
data = []
data.append(load_file(filename1))
data.append(load_file(filename2))
data.append(load_file(filename3))
data.append(load_file(filename4))
data.append(load_file(filename5))
test_batch = load_file(filename_test)
# 作业要求的K值
k = [1,3,5,7,9]
# 初始化几个会用到的list
result = []
validation = []
# 建立模型
net = KNearestNeighbor(k)
# 5个batch分别迭代
for i in range(5):
# 训练
net.train(data[i]['data'], data[i]['labels'])
# 预测
result.append(net.predict(test_batch['data']))
# 计算预测结果与实际label之间的误差
temp = result[i] - np.array(test_batch['labels']).reshape((np.array(test_batch['labels']).shape[0],1))
temp[temp != 0] = 1
# 计算准确度
validation.append(1 - (sum(abs(temp)) / test_batch['data'].shape[0]))
print('batch', i, '', validation[i])
print(validation)
# 绘图
plt.title("Cross Validation")
plt.xlabel("k")
plt.ylabel("validation")
plt.axis([0, 10, 0.2, 0.26])
validation = np.matrix(validation).T
ave = np.sum(validation, axis = 1) / len(k)
plt.plot(k, ave)
fig = plt.plot(k, validation, 'ro')
plt.setp(fig, color='b')
plt.savefig('fig')
plt.show()
-
KNN结果
可以看到当k=7的时候,准确度最高(其实也高不到哪儿去)
fig.png