决策树实现

2020-07-26  本文已影响0人  墩er

基于ID3算法的信息增益来实现

#!/usr/bin/python
#encoding:utf-8

from math import log
import operator
import treePlotter
import sys
reload(sys)
sys.setdefaultencoding("utf-8")

def createDataSet():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing','flippers']
    #change to discrete values
    return dataSet, labels

##### 计算信息熵 ######
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)  # 样本数
    labelCounts = {}   # 创建一个数据字典:key是最后一列的数值(即标签,也就是目标分类的类别),value是属于该类别的样本个数
    for featVec in dataSet: # 遍历整个数据集,每次取一行
        currentLabel = featVec[-1]  #取该行最后一列的值
        if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0  # 初始化信息熵
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        shannonEnt -= prob * log(prob,2) #log base 2  计算信息熵
    return shannonEnt

##### 按给定的特征划分数据 #########
def splitDataSet(dataSet, axis, value): #axis是dataSet数据集下要进行特征划分的列号例如outlook是0列,value是该列下某个特征值,0列中的sunny
    retDataSet = []
    for featVec in dataSet: #遍历数据集,并抽取按axis的当前value特征进划分的数据集(不包括axis列的值)
        if featVec[axis] == value: # 特征值==value的样本
            reducedFeatVec = featVec[:axis]     #chop out axis used for splitting
            reducedFeatVec.extend(featVec[axis+1:]) #不包括axis列的值
            retDataSet.append(reducedFeatVec)
            # print axis,value,reducedFeatVec
    # print retDataSet
    return retDataSet

##### 选取当前数据集下,用于划分数据集的最优特征
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1      #获取当前数据集的特征个数,最后一列是分类标签
    baseEntropy = calcShannonEnt(dataSet)  #计算当前数据集的信息熵
    bestInfoGain = 0.0; bestFeature = -1   #初始化最优信息增益和最优的特征
    for i in range(numFeatures):        #遍历每个特征iterate over all the features
        featList = [example[i] for example in dataSet]#获取数据集中当前特征下的所有值
        uniqueVals = set(featList)       # 获取当前特征值,例如outlook下有sunny、overcast、rainy
        newEntropy = 0.0
        for value in uniqueVals: #计算每种划分方式的信息熵
            subDataSet = splitDataSet(dataSet, i, value) # 特征值==value的样本集合
            prob = len(subDataSet)/float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy     #计算信息增益
        if (infoGain > bestInfoGain):       #比较每个特征的信息增益,只要最好的信息增益
            bestInfoGain = infoGain         #if better than current best, set to best
            bestFeature = I
    return bestFeature                      #returns an integer

#####该函数使用分类名称的列表,然后创建键值为classList中唯一值的数据字典。字典
#####对象的存储了classList中每个类标签出现的频率。最后利用operator操作键值排序字典,
#####并返回出现次数最多的分类名称
def majorityCnt(classList):
    classCount={}
    for vote in classList:
        if vote not in classCount.keys(): classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

##### 生成决策树主方法
def createTree(dataSet,labels):
    classList = [example[-1] for example in dataSet] # 返回当前数据集下标签列所有值
    if classList.count(classList[0]) == len(classList):
        return classList[0]#当类别完全相同时则停止继续划分,直接返回该类的标签
    if len(dataSet[0]) == 1:  # 样本只有一个特征
    ##遍历完所有的特征时,仍然不能将数据集划分成仅包含唯一类别的分组 dataSet
        return majorityCnt(classList) #由于无法简单的返回唯一的类标签,这里就返回出现次数最多的类别作为返回值
    bestFeat = chooseBestFeatureToSplit(dataSet) # 获取最好的分类特征索引
    bestFeatLabel = labels[bestFeat] #获取该特征的名字

    # 创建树
    # 这里直接使用字典变量来存储树信息,这对于绘制树形图很重要。
    myTree = {bestFeatLabel:{}} #当前数据集选取最好的特征存储在bestFeat中
    del(labels[bestFeat]) #删除已经在选取的特征
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]       #copy all of labels, so trees don't mess up existing labels
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels)
    return myTree

# 使用决策树执行分类
def classify(inputTree,featLabels,testVec):
    firstStr = inputTree.keys()[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    key = testVec[featIndex]
    valueOfFeat = secondDict[key]
    if isinstance(valueOfFeat, dict): 
        classLabel = classify(valueOfFeat, featLabels, testVec)
    else: classLabel = valueOfFeat
    return classLabel

# 使用Pickle模块存储决策树
def storeTree(inputTree,filename):
    import pickle
    fw = open(filename,'w')
    pickle.dump(inputTree,fw)
    fw.close()

def grabTree(filename):
    import pickle
    fr = open(filename)
    return pickle.load(fr)


if __name__ == '__main__':
    fr = open('play.tennies.txt')
    lenses =[inst.strip().split(' ') for inst in fr.readlines()]
    lensesLabels = ['outlook','temperature','huminidy','windy']
    lensesTree =createTree(lenses,lensesLabels)
    treePlotter.createPlot(lensesTree)


数据集 .png 决策树.png 创建树.png 决策树分类.png

使用Pickle模块存储决策树

pickle是为了序列化/反序列化一个对象的,可以把一个对象持久化存储。 比如你有一个对象,想下次运行程序的时候直接用,可以直接用pickle打包存到硬盘上。或者你想把一个对象传给网络上的其他程序,可以用pickle打包,然后传过去

参考:
决策树算法及Python实现
https://blog.csdn.net/qq_34807908/article/details/81539536

上一篇下一篇

猜你喜欢

热点阅读