kNN(构造kd树的实现)

2018-08-24  本文已影响0人  Gantowell

k近邻模型


1. 三个基本要素:k值的选择,距离度量,分类决策规则

2. 算法:对给定的输入实例,在训练集中找出与其距离最近的k个样本,再按分类决策规则(如多数表决)确定该实例的类别。

3. kd树的构造:在每个节点上,按“节点深度对特征维数取模”选出切分维度,以该维度上的中位数样本作为切分点,把其余样本分到左、右子树,并递归进行,直到子区域内只剩一个样本。



python实现

class node():
    """A single node of a binary (kd-)tree.

    Attributes are stored exactly as passed to the constructor:

    value -- the data held by this node (``None`` until assigned).
    left, right -- child nodes, or ``None`` when the child is absent.
    depth -- depth of this node in the tree.
    space -- the samples that still have to be partitioned at this node.
    """

    def __init__(self, value=None, left=None, right=None, depth=None, space=None):
        self.value = value
        self.left = left
        self.right = right
        self.depth = depth
        self.space = space

    def orderTraverse(self, root):
        """Print the values of the tree rooted at ``root`` in level order.

        The tree is walked breadth-first (left child before right child) and
        the collected ``value`` attributes are printed as one list.  An empty
        list is printed when ``root`` is ``None``.  Nothing is returned.
        """
        # Function-scope import keeps this class self-contained; a deque gives
        # O(1) pops from the front, whereas list.pop(0) shifts every element.
        from collections import deque

        values = []
        if root is not None:
            pending = deque([root])
            while pending:
                current = pending.popleft()
                values.append(current.value)
                if current.left is not None:
                    pending.append(current.left)
                if current.right is not None:
                    pending.append(current.right)
        print(values)


class kdtree():
    """kd-tree built by recursive median splits over a 2-D sample array.

    Attributes:
    root -- root node of the tree, or ``None`` while no tree has been built
            (or when it was built from an empty training set).
    K -- number of feature columns of the training set (0 when no tree).
    """

    def __init__(self):
        # Initialise the state here so that visualize() on a tree that has
        # not been constructed yet does not fail with AttributeError.
        self.root = None
        self.K = 0

    # Tree generation.
    # depth: depth of the node; the root has depth 0.
    # space: the samples that must be partitioned at this node, i.e. the
    #        region the node covers before it is split.
    # axis (curFeature in the original): index of the feature used to split.
    # median (curMedian in the original): row index of the median sample
    #        along that feature after sorting.

    def generateNode(self, curNode):
        """Recursively fill in ``curNode`` and its subtrees from ``curNode.space``.

        ``curNode.space`` is indexed as a 2-D array (rows are samples) and
        ``curNode.depth`` selects the split feature as ``depth % self.K``.
        The median row along that feature becomes ``curNode.value``; rows
        sorted before it go to the left subtree, rows after it to the right.
        A node whose space is empty is left untouched.
        """
        samples = curNode.space
        count = samples.shape[0]
        if count == 0:
            # Nothing to store; avoids indexing into an empty array below.
            return
        if count == 1:
            curNode.value = samples[0]
            return
        axis = curNode.depth % self.K
        ordered = samples[samples[:, axis].argsort()]
        # For an even number of rows this picks the upper of the two middle rows.
        median = count // 2
        curNode.value = ordered[median]
        upper = ordered[median + 1:]
        if upper.shape[0] > 0:
            curNode.right = node(depth=curNode.depth + 1, space=upper)
            self.generateNode(curNode.right)
        lower = ordered[:median]
        if lower.shape[0] > 0:
            curNode.left = node(depth=curNode.depth + 1, space=lower)
            self.generateNode(curNode.left)

    # Construct the kd-tree; the resulting root is stored in self.root.
    def constructTree(self, trainset):
        """Build the tree from ``trainset`` and store its root in ``self.root``.

        ``trainset`` must expose ``shape`` and NumPy-style indexing, with one
        sample per row.  An empty training set yields an empty tree
        (``self.root`` is ``None``) instead of raising ``IndexError``.
        """
        if trainset.shape[0] == 0:
            # Checked before reading shape[1]: np.array([]) has shape (0,).
            self.K = 0
            self.root = None
            return
        self.K = trainset.shape[1]
        self.root = node(depth=0, space=trainset)
        self.generateNode(self.root)

    def visualize(self):
        """Print the node values of the tree in level order.

        Prints an empty list when the tree is empty or not built yet.
        """
        if self.root is None:
            print([])
            return
        self.root.orderTraverse(self.root)

if __name__ == "__main__":
    # NumPy is used below but was never imported in this file; importing it
    # here keeps the demo runnable on its own.
    import numpy as np

    # Six 2-D sample points used as the demo training set.
    trainset = [[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]]
    trainset = np.array(trainset)
    mytree  = kdtree()
    mytree.constructTree(trainset)
    # Prints the node values level by level, starting from the root.
    mytree.visualize()






上一篇 下一篇

猜你喜欢

热点阅读