学习scikit-learn

scikit-learn--Decision Trees(决策树

2017-05-29  本文已影响65人  DayDayUp_hhxx

决策树是一种用来分类、回归、非参数、有监督的学习方法。

决策树的优点:

1.简单易懂,可视化展现;
2.相比其他方法,仅需要很少的数据准备工作,但是不支持缺失值;
3.决策树的时间复杂度是用来训练决策树的数据点的对数(The cost of using the tree (i.e., predicting data) is logarithmic in the number of data points used to train the tree.);
4.能够处理数值型和分类型数据;
5.能够解决多个输出(multi-output)问题;
6.是一种white box模型,通过树的逻辑关系,很容易解释输出结果;
7.使用统计检验验证模型的效果;
8.即使一些假设与数据生成的真实模型是相反的,也能有较好的预测结果。(Performs well even if its assumptions are somewhat violated by the true model from which the data were generated)

决策树的不足:

1.容易过拟合,生成很复杂的树,可以通过设置叶子节点的最小样本数或者树的最大深度去避免过拟合;
2.决策树是不稳定的,因为数据的很小变化就可能导致完全不同的结果,可以使用ensemble方法解决这个问题;
3.训练最优决策树是完全NP问题,实际的决策树学习算法是基于启发式算法(比如贪婪算法)寻求每个节点的局部最优解,不能得到全局最优解;可以使用随机森林解决这个问题;
4.决策树很难学习以下问题,比如XOR,parity or multiplexer problems;
5.如果一些类处于支配地位,将得到有偏的决策树,因此应在拟合决策树之前平衡数据集。

Tips

1.当特征很多的时候,决策树很容易过拟合,知道特征对应的样本占比非常重要,因为一棵小样本的树在高维空间很可能过拟合;
2.考虑维归约(PCA,ICA,特征选择),能让决策树发现更有判别力的特征;
3.训练模型时使用export函数可视化决策树,初始化的时候可以选择深度为3,然后逐渐增加深度;
4.记住树的深度每增加一层所需要的样本数,使用max_depth控制树的规模来避免过拟合;
5.使用 min_samples_split 或者 min_samples_leaf 控制叶子节点数,一个很小的值通常导致树过拟合,而一个较大的值将不会得到很好的学习效果。通常使用min_samples_leaf=5作为初始值,如果样本大小变化很大,也可以使用百分比作为参数。min_samples_leaf 决定了叶子节点的最小样本数,而 min_samples_split 能生成更小的叶子节点,因此min_samples_split 更普遍。
6.在训练前平衡你的数据集,以防止树偏向于占主导地位的类;可以为每类抽取相同的样本数,或者给每类相同的样本权重(sample_weight);同时也要注意到,基于权重的预减枝准则(例如min_weight_fraction_leaf)对于占主导地位的类产生更小的偏倚,相比没有考虑样本权重的min_samples_leaf ;
7.如果使用了样本加权,将很容易使用基于权重的预减枝准则(例如min_weight_fraction_leaf)对树进行优化;
8.决策树内部使用 np.float32 数组,如果训练数据集不是这种类型,将生成副本;
9.如果输入矩阵是稀疏的,建议在拟合之前转化为稀疏的csc_matrix,以及预测前转换为csr_matrix。相比很多样本中的特征有很多零值的稠密矩阵,稀疏矩阵的训练时间将有数量级的提升。

可视化例子

from sklearn.datasets import load_iris
from sklearn import tree
import sys
import os     
import pydotplus 
iris = load_iris()
clf = tree.DecisionTreeClassifier()
clf = clf.fit(iris.data, iris.target)
from IPython.display import Image  
dot_data = tree.export_graphviz(clf, out_file=None, 
                         feature_names=iris.feature_names,  
                         class_names=iris.target_names,  
                         filled=True, rounded=True,  
                         special_characters=True)  
graph = pydotplus.graph_from_dot_data(dot_data)  
Image(graph.create_png())

来源:http://scikit-learn.org/stable/modules/tree.html

上一篇 下一篇

猜你喜欢

热点阅读