K-近邻分类算法解析实践

2019-05-13  本文已影响0人  qiufeng1ye

教材选用《统计学习方法》,第一版,李航著;
代码取自《机器学习实战》,人民邮电出版社;


K-近邻算法(K-NearestNeighbours)是一种最简单的基本分类方法,它的定义为:在给定训练集中,找到与新的输入实例最相近的K个实例,新实例的类型就被划分为这K个实例多数属于的类型。

K-近邻算法的模型实际对应特征空间的划分,模型由距离度量、K值的选择和分类决策规则所决定。
距离度量一般使用欧氏距离,不同的距离度量所确定的最近邻点是不同的。距离度量相关的资料点这里。
K值的选择会对模型结果产生重大影响,在应用中K一般选一个较小的值,再采用交叉验证法找出最优的K值。
分类决策规则一般选用多数表决,多数表决规则等价于经验风险最小化。

K-近邻算法最简单的实现方法是线性扫描,但考虑到效率问题采用kd树实现可以优化速度,kd树更适用于训练实例远大于空间维数时的K-近邻搜索。


K-近邻算法的特点
K-近邻算法的一般流程

以下为K-近邻算法的简单实现例子,首先将以下4个数据点分为蓝和红两类,然后通过K-近邻算法找出新输入的点属于哪一类。

1.使用Python导入数据

首先导入科学计算包numpy和operator模块,然后建立数据集,样本为4个点,前两个设为红点,后两个设为蓝点。(运行环境为Python3.6)

from numpy import *
import operator

def createDataSet():
    group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]]) 
    labels = ['r','r','b','b']
    return group, labels

为了更清晰地看出数据之间的关系,通过数据可视化包matplotlib绘制出散点图。散点图的绘制教程点这里。

import matplotlib.pyplot as plt

fig = plt.figure()
ax = fig.add_subplot(111)
ax.scatter(group[:,0],group[:,1],15,labels)
plt.show()
4个数据点的散点图

2.K-近邻算法的实现

接下来定义分类算法classify0(),下面是Python代码,然后会详细解释每行代码的含义。

from os import listdir

def classify0(inX, dataSet, labels, k): 
    #1 计算欧氏距离,见下方计算公式
    dataSetSize = dataSet.shape[0] #获取数据形状
    diffMat = tile(inX, (dataSetSize,1)) - dataSet #用tile()重复计算新输入分类点和样本点之前的差值数组
    sqDiffMat = diffMat**2 #将差值数组平方
    sqDistances = sqDiffMat.sum(axis=1) #将平方后的数组累加
    distances = sqDistances**0.5 #开方得欧氏距离distances 
    sortedDistIndicies = distances.argsort() #argsort()按从小到大顺序排序所有点距离
    #2 得到主要分类
    classCount={}   #新建classCount数组
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]  #取出数组标签
        classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1 #按前k个投票给点,得到主要标签
    #3 返回频率最高的分类
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

classify0有四个输入:inX为新输入的分类点,dataSet为样本训练集(4个数据点),labels为分类标签(蓝和红),k为选择近邻的参数。

1.按照公式计算新输入点和所有样本点之间的欧氏距离;


欧氏距离计算公式

2.按照从小到大的次序进行排序,确定前k个距离最小元素的主要分类;

3.将classcount分解为元组列表,按照第二个元素从大到小排序,返回频率最高的元素标签。

3.测试分类器classify0

为了测试数据所有分类,在命令行中输入如下命令:

print(classify0([0.8,0.8], group, labels, 3))  #测试用了(0.8,0.8)这个点,k值用了3

在调试模式下运行命令,查看变量的属性。dataSetSize 为4,代表样本4个点;diffMat 显示了前两个点的差值( -0.2 = 0.8 - 1)。


调试中的变量截图1

sqDiffMat 为 diffMat数组的平方; sqDiffMat 为 sqDiffMat 数组中的平方值相加;distances 为sqDiffMat 数组的开方;sortedDistIndicies 为distances 进行了排序,并取索引为新数组。


调试中的变量截图2
调试中的变量截图3
调试中的变量截图4

classcount统计出样本空间里前k=3个距离最近的点,红色的为2个,蓝色的1个;sortedclasscount为排序后的标签数组,因此预测新输入的点属于红色类别。


调试中的变量截图5
调试中的变量截图6 因此,最后按照分类器得出的结果是:r为红点。 测试结果 可视化结果
上一篇 下一篇

猜你喜欢

热点阅读