决策树代码(ID3 决策树,西瓜数据集示例)
#导入包
from collections import Counter
from math import log
from operator import itemgetter

from graphviz import Digraph

# 构造数据集
def createDataset():
    """Build the watermelon training set.

    Returns:
        tuple: (dataset, label) where `dataset` is a list of 17 rows, each
        row being six feature strings followed by the class label
        ('好瓜' / '坏瓜'), and `label` is the list of six feature names.
    """
    # Each row: 色泽 根蒂 敲击 纹理 脐部 触感 类别
    raw_rows = [
        '青绿 蜷缩 浊响 清晰 凹陷 硬滑 好瓜',
        '乌黑 蜷缩 沉闷 清晰 凹陷 硬滑 好瓜',
        '乌黑 蜷缩 浊响 清晰 凹陷 硬滑 好瓜',
        '青绿 蜷缩 沉闷 清晰 凹陷 硬滑 好瓜',
        '浅白 蜷缩 浊响 清晰 凹陷 硬滑 好瓜',
        '青绿 稍蜷 浊响 清晰 稍凹 软粘 好瓜',
        '乌黑 稍蜷 浊响 稍糊 稍凹 软粘 好瓜',
        '乌黑 稍蜷 浊响 清晰 稍凹 硬滑 好瓜',
        '乌黑 稍蜷 沉闷 稍糊 稍凹 硬滑 坏瓜',
        '青绿 硬挺 清脆 清晰 平坦 软粘 坏瓜',
        '浅白 硬挺 清脆 模糊 平坦 硬滑 坏瓜',
        '浅白 蜷缩 浊响 模糊 平坦 软粘 坏瓜',
        '青绿 稍蜷 浊响 稍糊 凹陷 硬滑 坏瓜',
        '浅白 稍蜷 沉闷 稍糊 凹陷 硬滑 坏瓜',
        '乌黑 稍蜷 浊响 清晰 稍凹 软粘 坏瓜',
        '浅白 蜷缩 浊响 模糊 平坦 硬滑 坏瓜',
        '青绿 蜷缩 沉闷 稍糊 稍凹 硬滑 坏瓜',
    ]
    dataset = [row.split() for row in raw_rows]
    label = ['色泽', '根蒂', '敲击', '纹理', '脐部', '触感']
    return dataset, label

# 计算香农熵
# 计算香农熵
def calcShannonEnt(dataset, feat):
    """Compute the Shannon entropy of column `feat` of `dataset`.

    Args:
        dataset: list of rows (lists); `feat` indexes one column
                 (use -1 for the class-label column).
        feat: column index whose value distribution is measured.

    Returns:
        float: H = -sum(p(x) * log2(p(x))) over the distinct values in
        the column; 0.0 for an empty dataset.
    """
    total = len(dataset)
    # Counter replaces the manual "if key not in dict: init 0" idiom.
    counts = Counter(row[feat] for row in dataset)
    entropy = 0.0
    for cnt in counts.values():
        prob = cnt / total
        entropy -= prob * log(prob, 2)
    return entropy

#按照某个特征划分数据集
# 按照某个特征划分数据集
def splitDtatset(dataset, index, value):
    """Select the rows whose `index`-th field equals `value`.

    The matched column itself is removed from each returned row, so the
    result is ready for the next level of tree construction.

    NOTE(review): the name keeps the original's "Dtatset" spelling because
    other functions in this file call it by that name.
    """
    return [row[:index] + row[index + 1:]
            for row in dataset if row[index] == value]

#选择最好的特征(即信息增益最大的特征)
# 选择最好的特征(即信息增益最大的特征)
def chooseBestFeature(dataset):
    """Pick the feature index with the largest information gain (ID3).

    Args:
        dataset: list of rows; the last column is the class label.

    Returns:
        int: index of the best feature, or -1 when no feature yields a
        positive gain.
    """
    numFeature = len(dataset[0]) - 1              # last column is the class, skip it
    baseEntropy = calcShannonEnt(dataset, -1)     # entropy of the class labels
    bestFeat = -1
    bestinfoGain = 0
    bestGainRate = 0
    for i in range(numFeature):                   # examine every feature
        uniquevals = set(rowdata[i] for rowdata in dataset)
        newEntropy = 0.0
        for value in uniquevals:
            subDataset = splitDtatset(dataset, i, value)
            prob = len(subDataset) / len(dataset)
            # BUG FIX 1: accumulate (+=) the weighted conditional entropy;
            # the original overwrote it (=) on every value.
            # BUG FIX 2: measure column -1 (the class label). After the
            # split removed column i, index i pointed at a different feature.
            newEntropy += prob * calcShannonEnt(subDataset, -1)
        infoGain = baseEntropy - newEntropy
        ##   ID3
        if infoGain > bestinfoGain:
            bestinfoGain = infoGain
            bestFeat = i

        ##   C4.5 (gain ratio) — kept for reference. The split info is the
        ##   entropy of feature i's own distribution in the PARENT dataset.
        # splitInfo = calcShannonEnt(dataset, i)
        # if splitInfo == 0:
        #     continue
        # GainRate = infoGain / splitInfo
        # if GainRate > bestGainRate:
        #     bestGainRate = GainRate
        #     bestFeat = i

    # BUG FIX 3: return AFTER the loop; the original returned inside it,
    # so only feature 0 was ever considered.
    return bestFeat

#按照分类后各个特征的信息增益进行排序
# 多数表决:返回出现次数最多的类别
def majorityEnt(classlist):
    """Return the most frequent class label in `classlist` (majority vote).

    Ties are broken by first appearance, matching the stable-sort behavior
    of the original implementation.
    """
    # Counter.most_common replaces the manual dict-count + sorted() dance.
    # FIX: the stray debug print() from the original is removed.
    return Counter(classlist).most_common(1)[0][0]

#递归构建决策树
# 递归构建决策树
def createTree(dataset, labels):
    """Recursively build an ID3 decision tree.

    Args:
        dataset: list of rows; last column is the class label.
        labels: feature names aligned with the feature columns.

    Returns:
        Either a class-label string (leaf) or a nested dict of the form
        {feature_name: {feature_value: subtree_or_leaf, ...}}.
    """
    classlist = [rowdata[-1] for rowdata in dataset]
    # Base case 1: every sample has the same class -> leaf node.
    if classlist.count(classlist[0]) == len(classlist):
        return classlist[0]
    # Base case 2: only the class column remains -> majority-vote leaf.
    if len(dataset[0]) == 1:
        return majorityEnt(classlist)
    bestFeat = chooseBestFeature(dataset)
    bestLab = labels[bestFeat]
    mytree = {bestLab: {}}
    # BUG FIX: do not mutate the caller's list (original did
    # `del labels[bestFeat]`); build the reduced label list as a new list.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    featvalues = [rowdata[bestFeat] for rowdata in dataset]
    for value in set(featvalues):
        mytree[bestLab][value] = createTree(
            splitDtatset(dataset, bestFeat, value), subLabels)
    return mytree

#可视化展示
# 可视化展示
def plot_model(tree, name):
    """Render the nested-dict decision tree with graphviz and open the view.

    Args:
        tree: nested dict produced by createTree.
        name: output filename for the .gv source (PNG is rendered from it).
    """
    graph = Digraph("G", filename=name, format='png', strict=False, encoding="utf-8")
    root_label = list(tree.keys())[0]
    # Node "0" is the tree root; _sub_plot numbers the rest globally.
    graph.node("0", root_label, fontname="Kaiti")
    _sub_plot(graph, tree, "0")
    graph.view()
root = "0"
def _sub_plot(g, tree, inc):
    global root
    first_label = list(tree.keys())[0]
    ts = tree[first_label]
    for i in ts.keys():
        if isinstance(tree[first_label][i], dict):
            root = str(int(root) + 1)
            g.node(root, list(tree[first_label][i].keys())[0],fontname="Kaiti")
            g.edge(inc, root, str(i),fontname="Kaiti")
            _sub_plot(g, tree[first_label][i], root)
        else:
            root = str(int(root) + 1)
            g.node(root, tree[first_label][i],fontname="Kaiti")
            g.edge(inc, root, str(i),fontname="Kaiti")

#主函数
# 主函数: build the tree from the watermelon data, render it, then print it.
if __name__ == "__main__":
    samples, feature_names = createDataset()
    decision_tree = createTree(samples, feature_names)
    plot_model(decision_tree, "decision_tree.gv")
    print(decision_tree)
上一篇下一篇

猜你喜欢

热点阅读