决策树底层实现及绘制(python)
2018-05-17 本文已影响0人
吃番茄的土拨鼠
# coding:utf-8
import matplotlib
import matplotlib.pyplot as plt
from collections import defaultdict
from math import log
import matplotlib.path as mpath
import matplotlib.patches as mpatches
import numpy as np
from matplotlib import font_manager as fm, rcParams
class DecTree:
    """ID3-style decision tree stored as nested dicts.

    Every row of a data set is a list whose last element is the class label
    and whose preceding elements are discrete feature values.

    A tree node is either a leaf ``{'cls': label}`` or a decision node
    ``{'feat': name, 'label': {value: child, ...}}``.  Each child node also
    carries ``'lb'``, the feature value of the edge that leads to it.
    """

    def __init__(self):
        pass

    def cacuChannoEnt(self, data_set):
        """Return the Shannon entropy (log base 2) of the class labels.

        :param data_set: list of rows; the class label is ``row[-1]``
        :return: entropy as a float; ``0.0`` for an empty data set
        """
        cls_dict = defaultdict(int)
        for data in data_set:
            cls_dict[data[-1]] += 1
        total = len(data_set)
        channo_ent = 0.0
        for num in cls_dict.values():
            # float() keeps the division exact under Python 2 as well.
            p = float(num) / total
            channo_ent -= p * log(p, 2)
        return channo_ent

    def createDataSet(self):
        """Return the built-in sample data set and its label list.

        :return: ``(dataSet, labels)``
        """
        dataSet = [
            [0, 0, 0, 0, 'maybe'], [0, 1, 1, 1, 'error'],
            [1, 1, 1, 0, 'yes'], [1, 0, 0, 1, 'no'],
            [1, 0, 1, 1, 'no'],
            [1, 0, 0, 0, 'yes'], [1, 0, 0, 0, 'maybe'],
            [1, 1, 1, 0, 'yes'],
        ]
        # NOTE(review): only three names (one duplicated) are given for the
        # four feature columns above -- confirm the intended feature names.
        labels = ['no surfaceing', 'flippers', 'flippers']
        return dataSet, labels

    def splitDataSet(self, data_set, feat, value):
        """Extract the rows whose column ``feat`` equals ``value``.

        The matching column is removed from every returned row.

        :param data_set: list of rows
        :param feat: column index to test
        :param value: value the column must equal
        :return: new list of new rows; the input rows are not modified
        """
        return [
            data[:feat] + data[feat + 1:]
            for data in data_set
            if data[feat] == value
        ]

    def chooseBestFeature(self, data_set):
        """Return the column index that yields the largest information gain.

        :param data_set: non-empty list of rows
        :return: index of the best feature column; ``0`` when no column has a
            strictly positive gain
        """
        base_ent = self.cacuChannoEnt(data_set)
        feat_num = len(data_set[0]) - 1  # last column is the class label
        best_feat = 0
        total = len(data_set)
        ent_gain = 0.0
        for i in range(feat_num):
            uni_vals = {data[i] for data in data_set}
            # Weighted entropy remaining after splitting on column i.
            ent_tmp = 0
            for v in uni_vals:
                sub_data_set = self.splitDataSet(data_set, i, v)
                p = float(len(sub_data_set)) / total
                ent_tmp += p * self.cacuChannoEnt(sub_data_set)
            cur_gain = base_ent - ent_tmp
            if cur_gain > ent_gain:
                best_feat = i
                ent_gain = cur_gain
        return best_feat

    def allCls(self, data_set):
        """Return the set of distinct class labels present in ``data_set``.

        :param data_set: list of rows; the class label is ``row[-1]``
        :return: set of labels
        """
        return {data[-1] for data in data_set}

    def createTree(self, data_set, feat_list):
        """Build a decision tree recursively.

        :param data_set: list of rows (feature values followed by the label)
        :param feat_list: one name per remaining feature column; the chosen
            name is stored in the node's ``'feat'`` entry.  ``classfiy``
            uses that name to index the input vector, so pass the original
            column indices if the tree is to be used for classification.
        :return: root node of the (sub)tree
        """
        # No features left: fall back to a majority vote.
        if len(feat_list) == 0:
            data_cls = [data[-1] for data in data_set]
            return {'cls': self.majorCnt(data_cls)}
        # All rows agree on the class: pure leaf.
        all_cls = self.allCls(data_set)
        if len(all_cls) == 1:
            return {'cls': all_cls.pop()}
        feat = self.chooseBestFeature(data_set)
        uni_vals = {v[feat] for v in data_set}
        node = {'feat': feat_list[feat], 'label': {}}
        for v in uni_vals:
            sub_dat_set = self.splitDataSet(data_set, feat, v)
            # Drop the consumed feature name so names stay aligned with the
            # columns of the reduced rows.
            sub_feat_list = feat_list[:feat] + feat_list[feat + 1:]
            child_nd = self.createTree(sub_dat_set, sub_feat_list)
            child_nd['lb'] = v
            node['label'][v] = child_nd
        return node

    def majorCnt(self, clsList):
        """Return the class that occurs most often in ``clsList``.

        Ties are broken in favour of the largest label.

        :param clsList: non-empty list of class labels
        :return: the majority label
        :raises ValueError: if ``clsList`` is empty
        """
        if not clsList:
            raise ValueError('clsList must not be empty')
        num_dict = defaultdict(int)
        for cls in clsList:
            num_dict[cls] += 1
        # The largest (count, label) pair is the most frequent class; the
        # previous code took the last element of a descending sort, which is
        # the least frequent one.
        return max(zip(num_dict.values(), num_dict.keys()))[1]

    def classfiy(self, vec, tree_root):
        """Classify ``vec`` by walking the tree from ``tree_root``.

        :param vec: sequence indexed by the ``'feat'`` value stored in each
            decision node
        :param tree_root: node returned by ``createTree``
        :return: the predicted class, or ``None`` when ``vec`` holds a
            feature value that has no branch in the tree
        """
        node = tree_root
        # A leaf (possibly the root itself) carries the answer directly.
        while 'cls' not in node:
            branches = node['label']
            v = vec[node['feat']]
            if v not in branches:
                return None
            node = branches[v]
        return node['cls']
class DecTreePlotter(object):
    """Draw a nested-dict decision tree with matplotlib.

    A decision node is ``{'feat': name, 'label': {value: child, ...}}`` and a
    leaf is ``{'cls': label}``; every child carries ``'lb'``, the text shown
    on the edge leading to it.  Drawing stores each node's position in the
    node itself under the ``'loc'`` key.
    """

    # Box style for decision (internal) nodes.
    decNode = dict(boxstyle='square', fc='0.8')
    # Box style for leaf nodes.
    leafNode = dict(boxstyle='round4', fc='0.4')

    def __init__(self):
        super(DecTreePlotter, self).__init__()

    def draw(self, tree_root):
        """Create a figure and draw the whole tree on it.

        :param tree_root: root node of the tree; it is annotated in place
            with ``'loc'`` entries
        """
        width, height = self._getSize(tree_root)
        _fig, ax = plt.subplots()
        ax.grid()
        pt = (0.5, 0.9)
        tree_root['loc'] = pt
        # Draw the root node; a tree may consist of a single leaf.
        if 'cls' in tree_root:
            root_txt = '{}'.format(tree_root['cls'])
        else:
            root_txt = 'feat:{}'.format(tree_root['feat'])
        plt.text(pt[0], pt[1], root_txt, horizontalalignment='center', size=10,
                 bbox=dict(boxstyle='round', facecolor='yellow', alpha=0.5))
        self.draw_retrieve(ax, [tree_root], width, height, height - 1)

    def draw_retrieve(self, ax, p_nodes, width, height, level):
        """Draw the children of ``p_nodes`` on one row, then recurse.

        :param ax: matplotlib axes to draw on
        :param p_nodes: already positioned parent nodes (each has ``'loc'``)
        :param width: widest row of the tree, as returned by ``_getSize``
        :param height: number of rows, as returned by ``_getSize``
        :param level: row index of the children; the row is placed at
            ``level / height`` on the y axis
        """
        child_nodes = []
        index = 0
        cell_width = 1.0 / width - 0.1 / width
        for pn in p_nodes:
            p_pt = pn['loc']
            # list(...) is required on Python 3, where values() is a view.
            for nd in list(pn.get('label', {}).values()):
                is_leaf = 'cls' in nd
                if is_leaf:
                    txt = '{}'.format(nd['cls'])
                else:
                    child_nodes.append(nd)
                    txt = 'feat:{}'.format(str(nd['feat']))
                txt_pt = ((index + 1) * cell_width, float(level) * (1.0 / height))
                nd['loc'] = txt_pt
                node_type = DecTreePlotter.leafNode if is_leaf else DecTreePlotter.decNode
                self.plotNode(ax, txt, txt_pt, p_pt, node_type)
                # Put the edge label halfway between parent and child.
                mid_pt = (txt_pt[0] / 2 + p_pt[0] / 2, txt_pt[1] / 2 + p_pt[1] / 2)
                plt.text(mid_pt[0], mid_pt[1], str(nd['lb']), color='red', size=20)
                index += 1
        if len(child_nodes) > 0:
            self.draw_retrieve(ax, child_nodes, width, height, level - 1)

    def _getSize(self, tree_root):
        """Measure the tree level by level.

        :param tree_root: root node of the tree
        :return: ``(width, height)`` where ``width`` is the largest number of
            nodes on any row and ``height`` is the number of rows containing
            decision nodes plus one
        """
        cur_nodes = [tree_root]
        width = len(cur_nodes)
        height = 0
        while len(cur_nodes) > 0:
            tmp_nodes = []
            height += 1
            cur_width = 0
            for node in cur_nodes:
                # A leaf root has no 'label' entry and hence no children.
                label_nodes = list(node.get('label', {}).values())
                cur_width += len(label_nodes)
                # Only decision nodes have a further row below them.
                tmp_nodes.extend([vo for vo in label_nodes if 'cls' not in vo])
            width = cur_width if cur_width > width else width
            cur_nodes = tmp_nodes
        return width, height + 1

    def plotNode(self, ax, nodeText, centerPt, parentPt, nodeType):
        """Draw one node box with an arrow coming from its parent.

        :param ax: matplotlib axes to draw on
        :param nodeText: text shown inside the box
        :param centerPt: ``(x, y)`` of the box, in axes-fraction coordinates
        :param parentPt: ``(x, y)`` of the parent, in axes-fraction coordinates
        :param nodeType: bbox style dict (``decNode`` or ``leafNode``)
        """
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print('{}-{}'.format(centerPt, parentPt))
        ax.annotate(nodeText, xy=parentPt, xycoords='axes fraction',
                    xytext=centerPt, textcoords='axes fraction',
                    va='center', ha='center', bbox=nodeType,
                    arrowprops=dict(arrowstyle='<-',
                                    connectionstyle="arc,angleA=60,angleB=20,rad=0.0"))
# Build a tree from the bundled sample data; the features are named by their
# column index so that classfiy can index the input vector with them.
tree = DecTree()
data_set, labels = tree.createDataSet()
root = tree.createTree(data_set, [0, 1, 2, 3])
# Test the classifier.
# NOTE(review): this three-feature sample set is never used; the loop below
# replays the training rows in data_set instead.
dataSet = [[0, 0, 0, 'maybe'], [1, 1, 1, 'yes'], [1, 0, 1, 'no'], [0, 1, 1, 'no'], [0, 0, 0, 'yes'],
           [1, 1, 1, 'yes']]
for vec in data_set:
    cls = tree.classfiy(vec, root)
    # Single-argument print(...) behaves the same on Python 2 and Python 3.
    print('vec:{},cls is {},real is {}'.format(vec, cls, vec[-1]))
# Draw the decision tree.
tree_plotter = DecTreePlotter()
tree_plotter.draw(root)
plt.show()