第三章 kd-tree

2019-11-15  本文已影响0人  sadamu0912
import numpy as np
from operator import itemgetter
# 获取两个特征点的二范数,也叫欧式距离
def _get_euclidean_distance(feature1,feature2):
    return ((feature1-feature2) ** 2).sum() ** 0.5
class Node(object):
    """A k-d tree node.

    Every attribute is supplied as a keyword argument and defaults to None.
    This lets one constructor serve as several: an empty placeholder
    (``Node()``), a leaf, or an internal node (``Node(feature=x, split=0)``).
    Positional ``*args`` are accepted but ignored.
    """

    def __init__(self, *args, **kwargs):
        # Parent node, if the caller sets it.
        self.father = kwargs.get("father")
        # Left and right child nodes (None when absent).
        self.left = kwargs.get("left")
        self.right = kwargs.get("right")
        # Feature vector (sample point) stored at this node.
        self.feature = kwargs.get("feature")
        # Index of the dimension this node splits on.
        self.split = kwargs.get("split")

    def equals(self, node):
        """Return True if ``node`` holds the same feature vector as this node.

        ``np.array_equal`` returns a plain bool in every case instead of
        raising: for lists, for arrays of different shapes, and for missing
        features. A ``node`` of None is never equal.
        """
        if node is None:
            return False
        return bool(np.array_equal(self.feature, node.feature))
class KDTree(object):
    """A k-d tree over feature vectors, with nearest-neighbour search.

    Typical use::

        tree = KDTree()
        root = tree.build_tree(features, 0)
        tree.nearest_neighbor_search(query, root)
    """

    def __init__(self):
        # Placeholder root; build_tree returns the real root instead of storing it here.
        self.root = Node()

    def _choose_feature(self, X, idxs):
        """Return the dimension with the largest variance over the rows ``X[idxs]``.

        Splitting on the highest-variance dimension tends to spread the data
        evenly across the tree. (build_tree below cycles dimensions by depth
        and does not call this helper.)
        """
        m = len(X[0])
        variances = ((j, self._get_variance(X, idxs, j)) for j in range(m))
        return max(variances, key=lambda x: x[1])[0]

    def _get_variance(self, X, idxs, dimension):
        """Return the population variance of column ``dimension`` over rows ``idxs``.

        Args:
            X: 2-D sample set (a sequence of rows).
            idxs: indices of the rows to include.
            dimension: column whose variance is computed.

        Returns:
            The variance over all selected rows, or 0.0 when ``idxs`` is empty.
        """
        n = len(idxs)
        if n == 0:
            return 0.0
        col = [X[idx][dimension] for idx in idxs]
        mean = sum(col) / n
        # Two-pass form of D(X) = E[(X - E[X])^2]. It is numerically steadier
        # than E(X^2) - E(X)^2 and considers every selected sample.
        return sum((xi - mean) ** 2 for xi in col) / n

    def _get_hyper_plane_dist(self, point, nd):
        """Return the distance from ``point`` to the splitting hyperplane of ``nd``.

        The hyperplane is axis-aligned on dimension ``nd.split``, so the
        distance is the absolute difference on that single coordinate.
        """
        j = nd.split
        return abs(point[j] - nd.feature[j])

    def search_node(self, point, node):
        """Descend from ``node`` towards the node whose feature equals ``point.feature``.

        ``point`` is a node-like object exposing ``.feature``. The descent
        stops at a node with an identical feature, at a leaf, or when the
        branch to follow is missing (None is then returned).

        On a tie on the split coordinate with a *different* point, the right
        branch is taken, matching get_search_path. Because build_tree may put
        equal coordinates on either side, a tied match in the left subtree can
        be missed.
        """
        target = point.feature
        while node is not None and (node.left is not None or node.right is not None):
            j = node.split
            if target[j] < node.feature[j]:
                node = node.left
            elif target[j] > node.feature[j]:
                node = node.right
            elif np.array_equal(node.feature, target):
                break
            else:
                # Tie on the split value but a different point: keep moving
                # (staying on this node would loop forever).
                node = node.right
        return node

    def get_search_path(self, point, node):
        """Return the features visited while descending from ``node`` towards ``point``.

        The path is recorded for later backtracking. At each internal node the
        left child is taken when ``point`` is strictly smaller on the split
        dimension, otherwise the right child. The path ends at a leaf, or
        early when the chosen child is missing. Returns [] when ``node`` is None.
        """
        if node is None:
            return []
        search_path = [node.feature]
        while node.left is not None or node.right is not None:
            j = node.split
            node = node.left if point[j] < node.feature[j] else node.right
            if node is None:
                break
            search_path.append(node.feature)
        return search_path

    def get_search_directions(self, point, node):
        """Return the turns taken by the same descent as get_search_path.

        Each entry is 1 for a step into a left child and 0 for a step into a
        right child. A step towards a missing child ends the walk and is not
        recorded.
        """
        search_directions = []
        while node is not None and (node.left is not None or node.right is not None):
            j = node.split
            if point[j] < node.feature[j]:
                node, direction = node.left, 1
            else:
                node, direction = node.right, 0
            if node is None:
                break
            search_directions.append(direction)
        return search_directions

    def _get_median_idx(self, X, idxs, feature):
        """Return the index (taken from ``idxs``) of the median row of ``X`` on column ``feature``.

        For an even count, the upper median (sorted position ``len(idxs) // 2``)
        is returned. Ties keep their order in ``idxs``.
        """
        ordered = sorted(idxs, key=lambda i: X[i][feature])
        return ordered[len(idxs) // 2]

    def _split_feature(self, X, idxs, feature, median_idx):
        """Partition ``idxs`` (excluding ``median_idx``) around the median's value on ``feature``.

        Returns:
            ``[left, right]``. Rows strictly below the split value go left;
            all others, ties included, go right.
        """
        split_val = X[median_idx][feature]
        idxs_split = [[], []]
        for idx in idxs:
            if idx == median_idx:
                continue
            idxs_split[0 if X[idx][feature] < split_val else 1].append(idx)
        return idxs_split

    def build_tree(self, points, depth):
        """Recursively build a k-d tree and return its root Node (None if ``points`` is empty).

        Args:
            points: sequence of equal-length feature vectors (e.g. a 2-D NumPy array).
            depth: depth of this subtree's root. The split dimension cycles as
                ``depth % k``, where k is the number of dimensions.

        After sorting on the split dimension, the median point becomes the
        node. Points before it form the left subtree and points after it form
        the right subtree.
        """
        if len(points) == 0:
            return None
        cutting_dim = depth % len(points[0])
        median = len(points) // 2
        ordered = sorted(points, key=itemgetter(cutting_dim))
        node = Node(feature=ordered[median], split=cutting_dim)
        node.left = self.build_tree(ordered[:median], depth + 1)
        node.right = self.build_tree(ordered[median + 1:], depth + 1)
        return node

    def nearest_neighbor_search(self, point, tree):
        """Return the stored feature in ``tree`` closest (Euclidean) to ``point``.

        Args:
            point: query vector (list or NumPy array) with the tree's dimensionality.
            tree: root node as returned by build_tree.

        Returns:
            The ``feature`` of the nearest node, or None when the tree is empty.
            Among equally close nodes, the first one reached is kept.

        This is a depth-first branch-and-bound search. At each node, the child
        on the query's side of the split is explored first. The other child is
        visited only if the splitting hyperplane is closer than the best
        distance found so far, which is the backtracking step.
        """
        if tree is None or tree.feature is None:
            return None
        target = np.asarray(point, dtype=float)
        best_feature = None
        best_dist = float("inf")
        # Each stack entry is (node, lower bound on the distance from target to
        # any point in that subtree). LIFO order explores the near side first.
        stack = [(tree, 0.0)]
        while stack:
            node, bound = stack.pop()
            if node is None or bound >= best_dist:
                continue
            dist = float(np.linalg.norm(np.asarray(node.feature, dtype=float) - target))
            if dist < best_dist:
                best_dist, best_feature = dist, node.feature
            j = node.split
            if target[j] < node.feature[j]:
                near, far = node.left, node.right
            else:
                near, far = node.right, node.left
            # Every point in the far subtree lies beyond the splitting hyperplane.
            stack.append((far, self._get_hyper_plane_dist(target, node)))
            stack.append((near, bound))
        return best_feature
# Demo: index six 2-D sample points, then print the stored point nearest to (2, 4.5).
features = np.array([[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]])
tree = KDTree()
node = tree.build_tree(features, depth=0)
print(tree.nearest_neighbor_search([2, 4.5], node))

代码有待优化：回溯节点的逻辑考虑得还不够周全。

参考文档:
<https://blog.csdn.net/pipisorry/article/details/52186307>

上一篇下一篇

猜你喜欢

热点阅读