web服务器

数据科学(机器学习: 应用)

2018-11-26  本文已影响81人  GHope

Sklearn之使用决策树预测隐形眼镜类型

眼科医生是如何判断患者需要佩戴隐形眼镜的类型的?用决策树我们可以帮助人们判断需要佩戴的镜片类型

lenses.txt
一共有24组数据,数据的Labels依次是age、prescript、astigmatic、tearRate、class,也就 是第一列是年龄,第二列是症状,第三列是是否散光,第四列是眼泪数量,第五列是最终的分类标签

使用Sklearn构建决策树

sklearn.tree.DecisionTreeClassifier

使用sklearn决策树 fit函数之前,需要对数据集编码

为了对string类型的数据序列化,需要先生成pandas数据,这样方便我们的序列化工作。这里我使用的方法是,原始数据->字典->pandas数据,编写代码如下

import pandas as pd
from sklearn.preprocessing import LabelEncoder
with open('examples/lenses.txt', 'r') as fr:                                        #加载文件
    lenses = [inst.strip().split('\t') for inst in fr.readlines()]        #处理文件
lenses_target = []                                                        #提取每组数据的类别,保存在列表里 
for each in lenses:
    lenses_target.append(each[-1])
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']            #特征标签
lenses_list = []                                                        #保存lenses数据的临时列
lenses_dict = {}                                                        #保存lenses数据的字典,用于生成pandas
for each_label in lensesLabels:                                            #提取信息,生成字典
    for each in lenses:
        lenses_list.append(each[lensesLabels.index(each_label)])
    lenses_dict[each_label] = lenses_list
    lenses_list = [] 
print(lenses_dict)                                                        #打印字典信息 
lenses_pd = pd.DataFrame(lenses_dict)                                    #生成pandas.DataFrame 
print(lenses_pd)   
原始数据->字典->pandas数据

将数据序列化

le = LabelEncoder()                                                        #创建LabelEncoder()对象,用于序列化
for col in lenses_pd.columns:                                         #为每一 列序列化 
    lenses_pd[col] = le.fit_transform(lenses_pd[col]) 
print(lenses_pd)   
序列化结果

使用Graphviz可视化决策树

Graphviz的是AT&T Labs Research开发的图形绘制工具,他可以很方便的用来绘制结构化的图形网络,支持多种格式输出,生成图片的质量和速度都不错。它的输入是一个用dot语言编写的绘图脚本,通过对输入脚本的解析,分析出其中的点,边以及子图,然后根据属性进行绘制。是使用Sklearn生成的决策树就是dot格式的,因此我们可以直接利用Graphviz将决策树可视化。

conda install Graphviz
conda install pydotplus
from sklearn.preprocessing import LabelEncoder, OneHotEncoder 
from sklearn.externals.six import StringIO
from sklearn import tree 
import pandas as pd 
import numpy as np 
import pydotplus
with open('examples/lenses.txt', 'r') as fr:                                        #加载文件
    lenses = [inst.strip().split('\t') for inst in fr.readlines()]        #处理文件
lenses_target = []                                                        #提取每组数据的类别,保存在列表里
for each in lenses:
    lenses_target.append(each[-1]) 
print(lenses_target)
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']            #特征标签
lenses_list = []                                                        #保存 enses数据的临时列表 
lenses_dict = {}                                                        #保存 lenses数据的字典,用于生成pandas 
for each_label in lensesLabels:                                            #提取信息,生成字典
    for each in lenses:
        lenses_list.append(each[lensesLabels.index(each_label)])
    lenses_dict[each_label] = lenses_list
    lenses_list = [] 
# print(lenses_dict)                                                        # 印字典信息 
lenses_pd = pd.DataFrame(lenses_dict)                                    #生成 pandas.DataFrame 
# print(lenses_pd)                                                        #打印pandas.DataFrame
le = LabelEncoder()                                                        #创建LabelEncoder()对象,用于序列化           
for col in lenses_pd.columns:                                            #序列化    
    lenses_pd[col] = le.fit_transform(lenses_pd[col]) 
# print(lenses_pd)                                                        #打印 编码信息
clf = tree.DecisionTreeClassifier(max_depth = 4)                        #创建 DecisionTreeClassifier()类 
clf = clf.fit(lenses_pd.values.tolist(), lenses_target)                    #使用数据,构建决策树 
dot_data = StringIO() 
tree.export_graphviz(clf, out_file = dot_data,                            #绘制 决策树
                    feature_names = lenses_pd.keys(),
                    class_names = clf.classes_,
                    filled=True, rounded=True,
                    special_characters=True) 
graph = pydotplus.graph_from_dot_data(dot_data.getvalue()) 
graph.write_pdf("tree.pdf")                                                #保存绘制好的决策树,以PDF的形式存储。
上一篇下一篇

猜你喜欢

热点阅读