小白学习深度学习之KNN
今天我们要说的是:
KNN,也叫作K近邻,他是一个常用的分类模型,他的设计思想要比之前的容易理解一些,但是他也有让人头大的地方,就是kd-tree,不过那都是后话了
例子
今天我们先看一个例子:
有一个老师,在想要在所有同学里面找到一个“老师眼里的小内裤,同学心中的小混蛋”来作为自己的耳目,但是苦于刚刚开学谁也不认识谁,这个工作就变得艰难起来,怎么才能最快找到这么一个小混蛋呢?老师想到一个办法,去看他身边的几个朋友,如果都有成为小混蛋的趋势,那么,这个八成也是个小混蛋了。
当然,这个例子并不是那么贴切,因为你怎么知道别人有成为小混蛋的趋势。但是,重点不在这里,家长经常和孩子说,要去找学习好的孩子一起玩;也有俗语说:近朱者赤近墨者黑。
这其实都表明了一个道理:
一个人的好坏可以通过周围人的好坏判断出来。
同理,一个特征向量所表征的特征也可以通过特征空间里与他相近的几个点表示出来。
什么是KNN
knn,k nearest neighbor,是一个常用的分类器(或用于聚类),我们通过选取K个邻近的点,来判断某个点的类别。但是怎么判断呢?
分类决策规则
knn使用了投票表决的方式进行分类,就像他的名字,“少数服从多数”,也就是说,K个邻近的点中,占比最多的类别就是该点的类别。
K值的选择
k值的选择十分关键:
如果你的K值取得过大,就会出现不相干的点也被算到了里面,从而学习不到特征,就是欠拟合。
如果你的K值取得特别的小,那么模型就会对细节特别的敏感,而且周围一旦出现噪声,就会使模型变得特别差,也就是过拟合。
距离量度
如何表征两个特征向量之间的距离呢?
而两个特征向量的距离又表示在什么呢?
如何表征两个特征向量之间的距离?
我们在这里简单的介绍一下常见的距离。
Lp距离,也叫作:Lp范数,我理解他就是对于向量长度的某种表征
Lp范数
当然,p是一个常数,一般来说,常被用到的有:
L1范数
L1范数其实就是向量各个分量的差的绝对值之和,又叫做曼哈顿距离,据说这个名字来源于曼哈顿的出租车司机在曼哈顿区沿着正方形的路径行驶得来的。
L2范数,也叫欧几里得距离,想必这个大家都熟悉,就不冗述了。
无穷范数
无穷范数,就是求两个特征向量中相差最大的维度。
需要注意的是:在不同距离量度下,点与点之间的距离是不同的
而在knn中,我们用的是L2范数,也就是欧几里得距离
KD-tree
根据前面说的,其实我们的KNN已经做好,那kd-tree是用来干嘛的呢?我们来思考这样的一个问题:
我们一共有五个点作为训练数据,然后又给了一个点作为测试数据,现在,我们想要的到测试点的类别,请问我们一共进行了多少次运算?
首先,我们算出了测试点到五个点的距离,总共5次运算,
然后,我们还需要对距离排序,使用快排的话,时间复杂度是O(nlogn)
那么我们再来看,现在给你1000000个点,请问你要计算多少次?
所以,问题就摆在我们面前了,如何有效的减少计算次数就是我们要进行优化的。
A K-D Tree(also called as K-Dimensional Tree) is a binary search tree where data in each node is a K-Dimensional point in space. In short, it is a space partitioning(details below) data structure for organizing points in a K-Dimensional space.
上面是来自GeeksForGeeks的解答
就是说什么呢?
kd树,就是K Dimensional Tree,是一个用来划分空间的特殊二叉树。具体如何搭建比较简单,网上也有很多教程,在这里我想说一些不一样的东西。
需要注意的是:这里的K不是Knn里面的K,而是说特征空间的维数。
为什么kd树可以减少运算次数
KD树以二叉树的形式,将每一个特征向量表示,然后,我们就可以利用某些剪枝算法,对树进行剪枝,然后我们就可以减少运算次数,这样说有些抽象,我们可以看一个例子。
例子:
我们假设要在这6个点中找到离(3,2)最近的点,我们需要怎么做呢?
很简单,根据切割超平面不断向下走,直到叶子节点(2,3),我们管这个叶子节点叫做“当前最近点”,但他其实并不是最近的点,只是我们最初估计的。
然后呢?就开始回溯,我们知道,如果(2,3)不是最近的点,那么,最近的点一定在以两点连线为半径r的圆里面,那么,我们就回溯到父节点,计算(3,2)到父节点所在切割超平面的距离d,如果d > r,那么说明超平面切割不到超矩形区域,与这个圆是没有交点的,也就是说,实际上,兄弟节点的那个点是无论如何都不可能成为最近点的,所以在下一次时,我们就直接回溯到父节点的父节点。这里,我们看到,实际上就是一次剪枝操作,兄弟节点被跳过了,我们三次计算变成了两次计算,可以想象如果数据多了的话,会节省多少时间。
kd树在二维空间的展开
代码:
import numpy as np
import matplotlib.pyplot as plt
import math
from stack import stack
class knn:
def __init__(self):
self.data = self.data_produce()
self.tree = self.create_branch(self.data, 0)
def data_produce(self, num=6):
x = np.random.randint(0, 50, num)
y = np.random.randint(0, 50, num)
data = []
for i in range(num):
tmp = [x[i], y[i]]
data.append(tmp)
return data
def draw_figure(self):
x = []
y = []
for i in self.data:
x.append(i[0])
y.append(i[1])
plt.scatter(x, y)
plt.show()
def create_branch(self, data, axia):
tree_branch = {}
if len(data) != 0 and len(data) != 1:
sorted_data = self.sort(data, axia)
if axia == 0:
axia = 1
else:
axia = 0
length = len(data)
index = length // 2
tree_branch['data'] = sorted_data[index]
tree_branch['left'] = self.create_branch(sorted_data[:index], axia)
tree_branch['right'] = self.create_branch(sorted_data[index+1:], axia)
elif len(data) == 1:
tree_branch['data'] = data[0]
return tree_branch
def search(self, point, axia):
"""
利用kd-tree进行最近邻搜索
思路:
首先,我们通过切分超平面找到近似最近点,
然后,我们去切分超平面的另一侧去看看有没有更近的点(通过到切分超平面的距离与到近似最近点的距离的比较)
一直回溯上去
"""
path_stack = stack()
def find_fake_near(tree, point, axia):
if 'left' not in tree.keys() and 'right' not in tree.keys():
return tree
else:
if point[axia] < tree['data'][axia]:
if axia == 0:
axia =1
else:
axia = 0
path_stack.push(tree)
return find_fake_near(tree['left'], point, axia)
else:
if axia == 0:
axia = 1
else:
axia = 0
path_stack.push(tree)
return find_fake_near(tree['right'], point, axia)
tree = self.tree
"""
initial three points
"""
cur_nearest_poi = find_fake_near(tree, point, 0)['data']
target_poi = point
doubt_poi = path_stack.pop()
axia = path_stack.length() % 2
def backtrace(cur_nearest_poi, target_poi, doubt_poi):
nonlocal axia
d1 = cal_distance(target_poi, cur_nearest_poi)
d2 = abs(doubt_poi['data'][axia] - target_poi[axia])
cnp_father = doubt_poi
if d1 > d2:
tmp = cur_nearest_poi
cur_nearest_poi = doubt_poi['data']
if 'left' in cnp_father.keys() and 'right' in cnp_father.keys():
if cnp_father['left'] == tmp:
return backtrace(cur_nearest_poi, target_poi, cnp_father['right'] )
else:
return backtrace(cur_nearest_poi, target_poi, cnp_father['left'] )
else:
if path_stack.length() != 0:
cnp_grand = path_stack.pop()
axia = path_stack.length() % 2
return backtrace(cur_nearest_poi, target_poi, cnp_grand)
else:
return cur_nearest_poi
elif d1 < d2 and path_stack.length() != 0:
doubt_poi = path_stack.pop()
return backtrace(cur_nearest_poi, target_poi, doubt_poi)
else:
return cur_nearest_poi
def cal_distance(poi1, poi2):
return math.sqrt(pow((poi1[0] - poi2[0]), 2) + pow((poi1[1] - poi2[1]), 2))
return backtrace(cur_nearest_poi, target_poi, doubt_poi)
def sort(self, data, axia):
if len(data) < 2:
return data
else:
pivot = data[0]
less = [i for i in data[1:] if i[axia] <= pivot[axia]]
greater = [i for i in data[1:] if i[axia] > pivot[axia]]
return self.sort(less, axia) + [pivot] + self.sort(greater, axia)
if __name__ == "__main__":
k = knn()
Nearest = k.search([20, 5], 0)
x = [i[0] for i in k.data if i != Nearest]
y = [j[1] for j in k.data if j != Nearest]
print(Nearest)
plt.scatter(x,y)
plt.scatter(20, 5)
plt.scatter(Nearest[0],Nearest[1])
plt.show()
在这里,说一下我写的代码:
1、首先,他还是有一些小bug的,但是大体是没错的,也欢迎大家帮我挑错。
2、在构建KD树时,我使用了栈去维护我的树,当然,手工实现了一个简单的栈(没有技术含量,下面贴代码)
3、在程序搭建里,大量使用了递归,尽可能的使用了尾递归优化,避免出现stack overflow,同时,不得不说,在处理树这个结构时,递归的思想既简单,又方便设计。而且涉及到需要回溯的问题,栈也是我们的不二选择。
下面是栈的简单实现:
class stack:
def __init__(self):
self.s_list = []
self._length = 0
self.top = None
self.bottom = None
def pop(self):
return self.s_list.pop()
def push(self, data):
self.s_list.append(data)
self.top = data
def is_empty(self):
if len(self.s_list) == 0:
return True
else:
return False
def length(self):
return len(self.s_list)
程序运行结果
总的来说,没有显式的学习过程,knn的设计思想还是比较简单的。
以上为我的个人看法,希望大家指出我的错误之处。
更新
晚上和小伙伴讨论了一下这个KD树的搭建,现在有一些新的想法记录一下。
对于回溯算法,开始我不知道怎么实现于是使用了一个路径栈来记录来的时候的路径,但是在在和小伙伴聊完以后,我又去查了一下回溯算法,这里简单的说一下:
回溯算法分为两种:
1、递归调用方法
2、非递归调用方法
我想先说一下第二个:
我在实现KD树的树结构时,采用的是python的dict作为基本结构,这样就导致我用非递归的方法回溯很困难,后来我想,如果我用list实现树的结构,应该就会简单不少,因为list可以使用index访问(也不知道为啥我对dict情有独钟,可能是哈希表看起来洋气吧)。
至于递归的回溯方法,大家先看一下别人是怎么介绍的;
回溯算法实际上一个类似枚举的搜索尝试过程,主要是在搜索尝试过程中寻找问题的解,当发现已不满足求解条件时,就“回溯”返回,尝试别的路径。
回溯法是一种选优搜索法,按选优条件向前搜索,以达到目标。但当探索到某一步时,发现原先选择并不优或达不到目标,就退回一步重新选择,这种走不通就退回再走的技术为回溯法,而满足回溯条件的某个状态的点称为“回溯点”。
不知道大家看完有没有感觉,就我而言,我想起了一个算法:
深度优先探索
递归到叶子节点,然后剪枝,最后回溯。
和KD树的搭建思想一模一样,之前还是太天真了。
想看更多有关回溯算法以及代码实现,请继续关注我的更新。