2019-03-05深度学习——搭建一个简单的KnnClassi

2019-03-05  本文已影响0人  Hie_9e55

KNN思想

NN选出与目标图片距离(范数)最近的一张图片
KNN选出与目标图片距离(范数)最近的K张图片,并统计K张图中出现次数最多的类型,即为预测类型

代码实现

分为两个部分
第一部分是knn模型
第二部分是模型的使用

  1. KNN.py
# KNN.py
# 导入所需要的库
# 这里我们需要使用numpy库进行矩阵运算
# 使用collections中的Counter
import numpy as np
from collections import Counter

class KNearestNeighbor:

    def __init__(self, k = 7):
        self.k = k

    # 训练模型,KNN只是简单的导入即可,因为K是一个超参数,X是数据,n*3072,Y是数据标签,n*1
    def train(self, X, y):
        self.Xtr = X
        self.ytr = y

    # 使用模型进行预测,X是test集的数据
    def predict(self, X):
        num_test = X.shape[0]# test数据个数
        Ypred = np.zeros((num_test, len(self.k)))# 初始化预测结果
        
        for i in range(num_test):# 每次迭代一张图片

            distances = np.sum(np.abs(self.Xtr - X[i,:]), axis = 1)# 计算范数一
            temp = np.argsort(distances)# numpy.argsort()函数返回的是数组值从小到大的索引值
            
            index = []
            for j in range(len(self.k)):
                # 取出前k个下标
                id = temp[0:self.k[j]]
                # 取出k个下标所对应的类型
                temp_y = np.array(self.ytr)[id]
            
                # 取出k个下标中出现次数最多的label
                index.append(Counter(temp_y).most_common(1)[0][0])

            print(np.array(index), i)

            Ypred[i] = np.array(index)# 记录第i张图片在进行knn之后的label

        return Ypred
  1. runKNN.py
# runKNN.py
# 导入所需要的库
# pickle用于解压数据
# matplotlib用于绘图
# numpy用于矩阵运算
# KNN用于预测label
import pickle
from matplotlib import pyplot as plt
import numpy as np
from KNN import KNearestNeighbor

# 数据地址
filename1 = 'D:/Download/cifar-10-batches-py/data_batch_1'
filename2 = 'D:/Download/cifar-10-batches-py/data_batch_2'
filename3 = 'D:/Download/cifar-10-batches-py/data_batch_3'
filename4 = 'D:/Download/cifar-10-batches-py/data_batch_4'
filename5 = 'D:/Download/cifar-10-batches-py/data_batch_5'
filename_test = 'D:/Download/cifar-10-batches-py/test_batch'

# 定义导入数据的函数
def load_file(filename):
    with open(filename, 'rb') as fo:
        data = pickle.load(fo, encoding='latin1')
    return data

# 使用pickle暴力导入数据  
data = []
data.append(load_file(filename1))
data.append(load_file(filename2))
data.append(load_file(filename3))
data.append(load_file(filename4))
data.append(load_file(filename5))
test_batch = load_file(filename_test)

# 作业要求的K值
k = [1,3,5,7,9]

# 初始化几个会用到的list
result = []
validation = []

# 建立模型
net = KNearestNeighbor(k)

# 5个batch分别迭代
for i in range(5):
    # 训练
    net.train(data[i]['data'], data[i]['labels'])
    # 预测
    result.append(net.predict(test_batch['data']))
    # 计算预测结果与实际label之间的误差
    temp = result[i] - np.array(test_batch['labels']).reshape((np.array(test_batch['labels']).shape[0],1))
    temp[temp != 0] = 1
    # 计算准确度
    validation.append(1 - (sum(abs(temp)) / test_batch['data'].shape[0]))

    print('batch', i, '', validation[i])

print(validation)

# 绘图
plt.title("Cross Validation") 
plt.xlabel("k") 
plt.ylabel("validation") 
plt.axis([0, 10, 0.2, 0.26])
validation = np.matrix(validation).T
ave = np.sum(validation, axis = 1) / len(k)
plt.plot(k, ave)
fig = plt.plot(k, validation, 'ro')
plt.setp(fig, color='b')
plt.savefig('fig')
plt.show()
  1. KNN结果
    可以看到当k=7的时候,准确度最高(其实也高不到哪儿去)


    fig.png
上一篇下一篇

猜你喜欢

热点阅读