Machine Learning

Decision Trees

2019-03-05  CSTDOG

Decision Tree Definition

A decision tree is built recursively: each node tests the feature that best separates the current data, and each branch handles one value of that feature. In pseudocode:

Function createBranch():
Check whether every subset of the dataset belongs to the same class
    If so, return the class label
    Else
        Find the best feature for splitting the dataset
        Split the dataset
        Create a branch node
            for each split subset
                call createBranch() and append the result to the branch node
        return the branch node
The best feature to split on is the one with the largest information gain, which is computed from Shannon entropy. For a dataset whose samples fall into n classes, with p(x_i) the proportion of samples in class i, the entropy is

H=-\sum_{i=1}^{n}p(x_i)\log_2 p(x_i)
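As a quick check, a set containing two "yes" labels and three "no" labels has entropy H=-\frac{2}{5}\log_2\frac{2}{5}-\frac{3}{5}\log_2\frac{3}{5}\approx 0.971; this is the value the code below should reproduce.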

Implementing the Decision Tree

from math import log
import operator

def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    # Count the samples in each class, used to compute p(x_i)
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    # Apply the entropy formula
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
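A quick sanity check of calcShannonEnt, using a hypothetical toy dataset (the names myData and labels are illustrative only, not from any library; each row holds two binary features plus a class label in the last column):

myData = [[1, 1, 'yes'],
          [1, 1, 'yes'],
          [1, 0, 'no'],
          [0, 1, 'no'],
          [0, 1, 'no']]
labels = ['no surfacing', 'flippers']
print(calcShannonEnt(myData))  # 0.9709505944546686, matching the hand calculation above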
# Split the dataset on a given feature
# extend() appends each value from another sequence to the end of a list
# Inputs: the dataset to split, the feature (column index) to split on, and the feature value to keep
# a[i:j] takes the elements from index i up to, but not including, index j
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            # Copy the row with the split feature's column removed
            reduceFeatVec = featVec[:axis]
            reduceFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reduceFeatVec)
    return retDataSet
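For example, splitting on feature 0 (a sketch assuming myData from the earlier example is in scope):

print(splitDataSet(myData, 0, 1))  # [[1, 'yes'], [1, 'yes'], [0, 'no']]
print(splitDataSet(myData, 0, 0))  # [[1, 'no'], [1, 'no']]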
# Choose the best feature to split the dataset on
def chooseBestFeatureToSplit(dataSet):
    # The number of columns minus the label column gives the number of features
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        # Collect every value that feature i takes across the dataset
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0
        # Compute the weighted entropy of the subsets produced by splitting on feature i
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        # Keep the feature with the largest information gain; bigger gain means a better split
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
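On the toy dataset, feature 0 yields the larger information gain (about 0.420 versus 0.171 for feature 1), so it is chosen as the root split (again assuming myData is in scope):

print(chooseBestFeatureToSplit(myData))  # 0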
# When the features are exhausted but the class labels still disagree,
# pick the final class label by majority vote
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
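For instance, given three "no" votes and two "yes" votes:

print(majorityCnt(['yes', 'no', 'no', 'yes', 'no']))  # no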
# Build the tree recursively
def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    # Stop if every sample in the current group has the same class label
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Stop if all features have been used up; fall back to majority vote
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del labels[bestFeat]
    # Collect every possible value of the chosen feature, then branch on each
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree
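Building the tree from the toy dataset; note that createTree deletes entries from the labels list it is given, so pass a copy if the original is still needed:

myTree = createTree(myData, labels[:])
print(myTree)
# {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}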
# Classify a sample with the decision tree
# Inputs: the tree, the list of feature labels, and the test vector
def classify(inputTree, featLabels, testVec):
    # The first key is the root, i.e. the feature tested first
    firstStr = list(inputTree.keys())[0]
    # The sub-dictionary holds the root's children
    secondDict = inputTree[firstStr]
    # Map the feature label back to its column position in the data
    featIndex = featLabels.index(firstStr)
    for key in secondDict.keys():
        # Follow the branch that matches the test vector's value
        if testVec[featIndex] == key:
            if isinstance(secondDict[key], dict):
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel
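Classifying two test vectors with the tree built above (labels here is the original, unmodified feature list):

print(classify(myTree, labels, [1, 0]))  # no
print(classify(myTree, labels, [1, 1]))  # yes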
# Persist the tree to disk with pickle
def storeTree(inputTree, filename):
    import pickle
    # pickle writes bytes, so the file must be opened in binary mode
    fw = open(filename, 'wb')
    pickle.dump(inputTree, fw)
    fw.close()


# Load a previously stored tree from disk
def grabTree(filename):
    import pickle
    fr = open(filename, 'rb')
    tree = pickle.load(fr)
    fr.close()
    return tree
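Round-tripping the tree through a file (the filename is arbitrary):

storeTree(myTree, 'classifierStorage.pkl')
print(grabTree('classifierStorage.pkl'))
# {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}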