随机森林

随机森林原理

2018-10-29  本文已影响20人  阿童89

1、随机森林步骤
1)给定包含N个样本的数据集,经过m次有放回的随机抽样操作,得到T个含m个训练样本的采样集
2)对每个采样集,从所有属性中随机选择k个属性,选择最佳分割属性作为节点建立CART模型,最终建立拥有T个CART模型的随机森林
注:k一般选择(其中d为样本所有属性):
k=log_{2}^{d}
3)将模型用于测试机,对于测试每个样本会有T个预测值,对分类任务使用简单投票法确定该样本最终预测值,对回归任务使用简单平均法确定该样本最终预测值

2、特征重要性
1)对整个随机森林,得到相应的袋外数据(out of bag,OOB)​计算袋外数据误差,记为errOOB1.
注:每个采样集只使用了初始训练集中约63.2%的样本【每个样本被抽到的概率是1/N,样本不被抽到概率就是1-1/N,总共抽了m次,第m次试验后样本不被抽到的概率是(1-1/N)m,当m趋近于无穷大时,(1-1/n)m=1/e,约等于36.8%】,另外抽不到的样本叫做out-of-bag(OOB) examples,这部分数据可以用于对决策树的性能进行评估,计算模型的预测错误率,称为袋外数据误差。这已经经过证明是无偏估计的,所以在随机森林算法中不需要再进行交叉验证或者单独的测试集来获取测试集误差的无偏估计

2)随机对袋外数据OOB所有样本的特征X加入噪声干扰(可以随机改变样本在特征X处的值),再次计算袋外数据误差,记为errOOB2。假设森林中有N棵树,则特征X的重要性=∑errOOB2−errOOB1N∑errOOB2−errOOB1N。这个数值之所以能够说明特征的重要性是因为,如果加入随机噪声后,袋外数据准确率大幅度下降(即errOOB2上升),说明这个特征对于样本的预测结果有很大影响,进而说明重要程度比较高。

3)在特征重要性的基础上,特征选择的步骤如下:
a)计算每个特征的重要性,并按降序排序
b)确定要剔除的比例,依据特征重要性剔除相应比例的特征,得到一个新的特征集
c)用新的特征集重复上述过程,直到剩下m个特征(m为提前设定的值)。
d)根据上述过程中得到的各个特征集和特征集对应的袋外误差率,选择袋外误差率最低的特征集。​

3、随机森林优点
随机森林中的基学习器多样性不仅来自样本扰动,还来自属性的扰动,这就使得最终模型的泛化性能可通过个体学习器之间的差异度增加而进一步提升

3、python代码
class sklearn.ensemble.RandomForestClassifier(
n_estimators=10, criterion='gini', max_depth=None, min_samples_split=2,
min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features='auto', max_leaf_nodes=None, min_impurity_split=1e-07, bootstrap=True, oob_score=False, n_jobs=1, random_state=None, verbose=0, warm_start=False,class_weight=None)
以下用常用参数:
1)n_estimators:设置多少个基分类器(取决于数据量)
2)min_samplies_split:单独叶子节点至少要有几个样本,
3)max_features:一次抽样抽多少feature,回归问题设置为特征数,分类问题设置为sqrt(n_features)
4)max_depth:树的最大深度(5-10)
5)oob_score

from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import precision_recall_curve
rf=RandomForestClassifier(n_estimators=10, criterion='gini', max_depth=5,
min_samples_split=2,  oob_score=True, n_jobs=1 )
param_grid ={"n_estimators":[5,10],'max_depth':[3,5]}
gscv = GridSearchCV(clf,param_grid,n_jobs= -1,verbose = 1,cv = 5,error_score = 0,scoring='auc')
gscv.fit(X,y)
gscv.best_score_
gscv.best_params_
gscv.predict_proba(X)#refit=True,gscv为最佳分类器
gscv.grid_scores_#score=roc_auc,auc值
上一篇下一篇

猜你喜欢

热点阅读