玩转手写字体识别,进一步了解KNN算法
上一节我们通过鸢尾花的案例带着大家了解了KNN算法,让大家对机器学习最简单的KNN算法有个大概的印象。
这一次我们从另外一个数据集——手写字体识别,进一步深入了解KNN算法。
第一步,导入计算包以及原始数据。
这个原始数据digits.csv.gz在sklearn的datasets里面,安装好sklearn后直接load即可。代码如下:
# 导包并load digit原始数据
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from sklearn.neighbors import KNeighborsClassifier
from sklearn import datasets
from sklearn.model_selection import cross_val_score
X = datasets.load_digits()
print(X.data) #打印查看数据特征
print(X.target) #打印查看数据对应的标签
print('*' * 20)
print(X.data.shape)
print(X.target.shape)
运行结果如下:
1.JPG
可以看到共有1797个样本,每个样本为64维数据,其标签是0-9中的某个数字。这64维是从一个8*8的二维数组拉平产生的。我们可以用plt命令显示查看一下,看看是否和标签值一致
# 将X.data使用图片显示,同时打印出target值,看看是否能够对应上
plt.imshow(X.data[102].reshape(8,8)) # 需要将数据reshape成一个8*8的二维数据
print('This image is: {} '.format(X.target[102]))
从data中任取一个样本,并将其reshape还原成8*8的图片格式,在用plt.imshow显示。只后在打印出它的标签看看是否可以对应起来。结果如下,可以看到图片和结果一致。
2.JPG第二步, 我们来分析KNN算法中K值的选择对程序运行时间的影响
首先,我们知道KNN算法的一个关键参数是K,即 neighbors邻居的数量,因而我们是否可以想一想,如果给出不同的K值,即需要寻找的最近的邻居数量不一样,是否算法花费的时间也不一样?稍微思考一下,是否是要寻找的邻居越多,程序花费的时间越多?是否是这样,我们来验证一下。
我们引入time模块统计代码的运行时间,通过统计不同的K值来比较一下就可以得到答案。
from time import time
def knn_time(k):
start = time()
from sklearn.model_selection import train_test_split
x_train,x_test,y_train,y_test = train_test_split(X.data,X.target,test_size=0.2)
knn = KNeighborsClassifier(n_neighbors=k)
knn.fit(x_train, y_train)
y_ = knn.predict(x_test)
stop = time()
return stop-start
# 统计K从1-20以内的程序运行时间
x = range(1,20)
k = []
for i in x:
t_ = knn_time(int(i))
k.append(t_)
from pylab import *
mpl.rcParams['font.sans-serif'] = ['SimHei'] #中文字体显示
plt.figure(figsize= (15,6))
plt.plot(x,k)
plt.xlabel('k value of KNN')
plt.ylabel('time')
plt.xticks(x)
plt.title('不同的K值下,算法的运行时间', fontsize = 15)
运行结果如下:
image.png
可以看到算法的花费时间并没有随着 K值的增大而增加,这是什么原因呢?这就是今天我想说的KNN算法的第一个特点,KNN属于惰性学习(lazy-learning),意思是KNN算法没有显式的学习过程!也就是说没有训练阶段,虽然我们的代码中有knn.fit(),但其实算法并没有进行训练。等到我们给算法一个新的待预测的样本时,算法直接计算该样本与训练样本空间中的所有样本的距离。算法的复杂度和样本数量n成正比,当以欧式距离来计算距离时算法的时间复杂度为O(nNN)。也就是说不论K取多少,算法都需要计算待预测样本与所有训练样本的距离,只后在排序,按照K值来切片前K个样本。对于KNN算法,大部分的时间花费在前期的距离计算与排序上。因为K值的选取并不显著的影响运行时间。
这也从另外一个侧面说明了KNN算法不太适合高维度的数据分类任务。因为维度越高,计算距离越困难。
最后,我们来总结一下KNN算法每预测一个“点”的分类都会重新进行一次全局运算,对于样本容量大的数据集计算量比较大。因而容易导致维度灾难,在高维空间中计算距离的时候,就会变得非常远;样本不平衡时,预测偏差比较大,k值大小的选择得依靠经验或者交叉验证得到。k的选择可以使用交叉验证,也可以使用网格搜索。k的值越大,模型的偏差越大,对噪声数据越不敏感,当 k的值很大的时候,可能造成模型欠拟合。k的值越小,模型的方差就会越大,当 k的值很小的时候,就会造成模型的过拟合。
以上就是本次的所有内容。加上前两期的内容,我们通过三个入门的案例,带着大家熟悉机器学习的核心内容:通过学习数据内部的关联性来对新的样本进行预测。归根到底,机器学习是用统计学加概率创建一个数学模型来处理数据。因为大部分的数据都是高维度的数据,这就需要用到线性代数的相关知识。这些知识我们会在以后的分享中详细说明。