决策树底层实现及绘制(python)

2018-05-17  本文已影响0人  吃番茄的土拨鼠
# coding:utf-8
import matplotlib
import matplotlib.pyplot as plt
from collections import defaultdict

from math import log
import matplotlib.path as mpath
import matplotlib.patches as mpatches
import numpy as np
from matplotlib import font_manager as fm, rcParams


class DecTree:
    """ID3-style decision tree over rows of discrete feature values.

    A sample (row) is a list whose last element is the class label; the
    preceding elements are feature values. ``createTree`` returns nested dicts:

    * decision node: ``{'feat': <entry of feat_list>, 'label': {value: child}}``
    * leaf node:     ``{'cls': <class label>}``

    Every child node additionally stores ``'lb'``: the feature value on the
    edge that leads to it from its parent.
    """

    def __init__(self):
        pass

    # Shannon entropy of the class column
    def cacuChannoEnt(self, data_set):
        """Return the base-2 Shannon entropy of the class labels in data_set.

        :param data_set: list of rows; the class label is each row's last element
        :return: entropy as a float (0.0 for an empty or single-class data set)
        """
        cls_dict = defaultdict(int)
        for data in data_set:
            cls_dict[data[-1]] += 1
        total = len(data_set)
        channo_ent = 0.0
        for num in cls_dict.values():
            p = float(num) / total
            channo_ent -= p * log(p, 2)
        return channo_ent

    def createDataSet(self):
        """Return a small sample data set and a list of feature names.

        :return: (dataSet, labels); each row of dataSet has 4 binary feature
            values followed by a class label
        """
        dataSet = [[0, 0, 0, 0, 'maybe'], [0, 1, 1, 1, 'error'], [1, 1, 1, 0, 'yes'], [1, 0, 0, 1, 'no'],
                   [1, 0, 1, 1, 'no'],
                   [1, 0, 0, 0, 'yes'], [1, 0, 0, 0, 'maybe'],
                   [1, 1, 1, 0, 'yes']]
        # NOTE(review): only 3 names (with 'flippers' repeated) for 4 feature
        # columns -- confirm the intended feature names.
        labels = ['no surfaceing', 'flippers', 'flippers']
        return dataSet, labels

    def splitDataSet(self, data_set, feat, value):
        """Return the rows whose column ``feat`` equals ``value``, minus that column.

        :param data_set: list of rows (lists)
        :param feat: column index to test
        :param value: value the column must equal
        :return: a new list of new rows; data_set itself is not modified
        """
        return [data[:feat] + data[feat + 1:] for data in data_set if data[feat] == value]

    def chooseBestFeature(self, data_set):
        """Return the feature column index with the largest information gain.

        :param data_set: non-empty list of rows (class label last)
        :return: column index; 0 when no column gives a positive gain
        """
        base_ent = self.cacuChannoEnt(data_set)
        feat_num = len(data_set[0]) - 1
        total = len(data_set)
        best_feat = 0
        ent_gain = 0.0
        for i in range(feat_num):
            # Weighted (conditional) entropy after splitting on column i.
            ent_tmp = 0.0
            for v in set(data[i] for data in data_set):
                sub_data_set = self.splitDataSet(data_set, i, v)
                p = float(len(sub_data_set)) / total
                ent_tmp += p * self.cacuChannoEnt(sub_data_set)
            cur_gain = base_ent - ent_tmp
            # Strict '>' keeps the earliest column when gains tie.
            if cur_gain > ent_gain:
                best_feat = i
                ent_gain = cur_gain
        return best_feat

    def allCls(self, data_set):
        """Return the set of distinct class labels (last column) in data_set."""
        return set(data[-1] for data in data_set)

    def createTree(self, data_set, feat_list):
        """Recursively build a decision tree.

        :param data_set: non-empty list of rows (class label last)
        :param feat_list: one name per feature column of data_set, in column
            order; the chosen entry is stored in each decision node's 'feat'
        :return: root node dict (a leaf ``{'cls': ...}`` when no split is needed)
        """
        if len(feat_list) == 0:
            # No features left to split on: fall back to the majority class.
            return {'cls': self.majorCnt([data[-1] for data in data_set])}
        all_cls = self.allCls(data_set)
        if len(all_cls) == 1:
            return {'cls': all_cls.pop()}
        feat = self.chooseBestFeature(data_set)
        # splitDataSet drops the chosen column from every subset, so its name
        # is dropped from the feature list too (same for every child).
        sub_feat_list = feat_list[:feat] + feat_list[feat + 1:]
        node = {'feat': feat_list[feat], 'label': {}}
        for v in set(row[feat] for row in data_set):
            child_nd = self.createTree(self.splitDataSet(data_set, feat, v), sub_feat_list)
            child_nd['lb'] = v
            node['label'][v] = child_nd
        return node

    def majorCnt(self, clsList):
        """Return the most frequent class label in clsList.

        Ties resolve to the label that was encountered first.
        :param clsList: non-empty iterable of class labels
        :raises ValueError: if clsList is empty
        """
        num_dict = defaultdict(int)
        for cls in clsList:
            num_dict[cls] += 1
        # max() returns the first maximal key, i.e. the earliest-seen label on ties.
        return max(num_dict, key=num_dict.get)

    def classfiy(self, vec, tree_root):
        """Classify one sample by walking the tree from tree_root.

        Each decision node's 'feat' (the feat_list entry given to createTree)
        is used directly as an index into vec, so feat_list must have held
        the original column indices (e.g. [0, 1, 2, 3]).

        :param vec: feature vector (a trailing class label is ignored)
        :param tree_root: node dict returned by createTree (may be a leaf)
        :return: the predicted class, or None if vec has a feature value that
            was not seen at some node during training
        """
        node = tree_root
        while 'cls' not in node:
            node = node['label'].get(vec[node['feat']])
            if node is None:
                return None
        return node['cls']


class DecTreePlotter(object):
    """Draw a tree produced by ``DecTree.createTree`` with matplotlib.

    Decision nodes are labelled ``feat:<feature>``, leaves show their class,
    and each edge is labelled in red with the feature value ('lb') it tests.
    All positions are axes-fraction coordinates (0..1 on both axes).
    """
    decNode = dict(boxstyle='square', fc='0.8')
    leafNode = dict(boxstyle='round4', fc='0.4')

    def __init__(self):
        super(DecTreePlotter, self).__init__()

    def draw(self, tree_root):
        """Draw the whole tree on a new figure.

        Side effect: every drawn node dict gets a 'loc' entry with its position.
        :param tree_root: root node dict (a decision node or a single leaf)
        :return: (fig, ax) of the new figure
        """
        fig, ax = plt.subplots()
        ax.grid()
        pt = (0.5, 0.9)
        tree_root['loc'] = pt
        if 'cls' in tree_root:
            # Single-leaf tree: there are no children to lay out.
            ax.text(pt[0], pt[1], '{}'.format(tree_root['cls']), transform=ax.transAxes,
                    horizontalalignment='center', size=10, bbox=DecTreePlotter.leafNode)
            return fig, ax
        width, height = self._getSize(tree_root)
        # Root node; transAxes keeps it in the same coordinate system as the
        # 'axes fraction' annotations drawn for the other nodes.
        ax.text(pt[0], pt[1], 'feat:{}'.format(tree_root['feat']), transform=ax.transAxes,
                horizontalalignment='center', size=10,
                bbox=dict(boxstyle='round', facecolor='yellow', alpha=0.5))
        self.draw_retrieve(ax, [tree_root], width, height, height - 1)
        return fig, ax

    def draw_retrieve(self, ax, p_nodes, width, height, level):
        """Draw all children of the nodes in p_nodes, then recurse one level down.

        :param ax: matplotlib Axes to draw on
        :param p_nodes: decision nodes of the current level; each must already
            have a 'loc' entry (its position)
        :param width: widest level from _getSize; sets horizontal spacing
        :param height: level count from _getSize; sets vertical spacing
        :param level: vertical slot of the children (y = level / height)
        """
        child_nodes = []
        index = 0
        # Children share 90% of the width so the last one stays off the edge.
        cell_width = 1.0 / width - 0.1 / width
        for pn in p_nodes:
            p_pt = pn['loc']
            # Iterate the values view directly: it is not indexable on Python 3.
            for nd in pn['label'].values():
                if 'cls' in nd:
                    txt = '{}'.format(nd['cls'])
                    node_type = DecTreePlotter.leafNode
                else:
                    child_nodes.append(nd)
                    txt = 'feat:{}'.format(str(nd['feat']))
                    node_type = DecTreePlotter.decNode
                txt_pt = ((index + 1) * cell_width, float(level) * (1.0 / height))
                nd['loc'] = txt_pt
                self.plotNode(ax, txt, txt_pt, p_pt, node_type)
                # Edge label (the tested feature value) halfway along the edge.
                mid_pt = (txt_pt[0] / 2 + p_pt[0] / 2, txt_pt[1] / 2 + p_pt[1] / 2)
                ax.text(mid_pt[0], mid_pt[1], str(nd['lb']), transform=ax.transAxes,
                        color='red', size=20)
                index += 1

        if child_nodes:
            self.draw_retrieve(ax, child_nodes, width, height, level - 1)

    def _getSize(self, tree_root):
        """Return (width, height) used to lay out the tree.

        width is the largest number of children on any level (at least 1);
        height is the number of levels containing decision nodes, plus one.
        A single-leaf tree yields (1, 2).
        """
        cur_nodes = [tree_root]
        width = len(cur_nodes)
        height = 0
        while cur_nodes:
            tmp_nodes = []
            height += 1
            cur_width = 0
            for node in cur_nodes:
                # Only the root can be a leaf here (leaves are never queued).
                label_nodes = list(node.get('label', {}).values())
                cur_width += len(label_nodes)
                tmp_nodes.extend(vo for vo in label_nodes if 'cls' not in vo)
            width = max(width, cur_width)
            cur_nodes = tmp_nodes
        return width, height + 1

    def plotNode(self, ax, nodeText, centerPt, parentPt, nodeType):
        """Draw one node box centred on centerPt, with an arrow to parentPt.

        Both points are axes-fraction coordinates.
        :param nodeType: bbox style dict (decNode or leafNode)
        """
        # Debug trace of the node/parent positions (call form works on Py2 and Py3).
        print('{}-{}'.format(centerPt, parentPt))
        ax.annotate(nodeText, xy=parentPt, xycoords='axes fraction',
                    xytext=centerPt, textcoords='axes fraction',
                    va='center', ha='center', bbox=nodeType,
                    arrowprops=dict(arrowstyle='<-', connectionstyle="arc,angleA=60,angleB=20,rad=0.0"))


# Demo: build a tree from the sample data, re-classify the training rows,
# then draw the tree.
tree = DecTree()
data_set, labels = tree.createDataSet()
# Use the column indices as feature "names": DecTree.classfiy indexes the
# input vector with whatever createTree stored in each node's 'feat'.
root = tree.createTree(data_set, list(range(len(data_set[0]) - 1)))
# Test the classifier.
# NOTE(review): dataSet has only 3 feature columns, unlike the 4-column rows
# of data_set, and is not used below -- the loop classifies data_set.
dataSet = [[0, 0, 0, 'maybe'], [1, 1, 1, 'yes'], [1, 0, 1, 'no'], [0, 1, 1, 'no'], [0, 0, 0, 'yes'],
           [1, 1, 1, 'yes']]
for vec in data_set:
    cls = tree.classfiy(vec, root)
    print('vec:{},cls is {},real is {}'.format(vec, cls, vec[-1]))

# Draw the decision tree.
tree_plotter = DecTreePlotter()
tree_plotter.draw(root)
plt.show()

上一篇 下一篇

猜你喜欢

热点阅读