BaggingClassifier

2018-02-08  本文已影响0人  taojinglong

写在前面

Ensemble methods 组合模型的方式大致为四个:/bagging / boosting / voting / stacking ,此文主要简单叙述 bagging算法。


算法主要特点

Bagging:


接下来进入主题

Bagging 算法:

WIKI百科:
Bagging算法 (英语:Bootstrap aggregating,引导聚集算法),又称装袋算法,是机器学习领域的一种团体学习算法。最初由Leo Breiman于1994年提出。Bagging算法可与其他分类、回归算法结合,提高其准确率、稳定性的同时,通过降低结果的方差,避免过拟合的发生。


实现原理:

  1. 数学基础


    这里写图片描述
  2. 图例描述


    这里写图片描述
  3. 实现描述

    在scikit-learn中,
    参数 max_samples 和 max_features 控制子集的大小(在样本和特征方面)
    参数 bootstrap 和 bootstrap_features 控制是否在有或没有替换的情况下绘制样本和特征。


实例分析:

  1. 实例环境

    sklearn + anconda + jupyter

  2. 实例步骤

    • 数据:可以采用 datasets 的数据,在此作者使用的是自己整理的股票行情
    • 训练、测试数据归一化
    • 参数寻优可以使用GridSearch,在此不作赘述

    参数描述:
    [图片上传失败...(image-2e684a-1518054828425)]

  3. 代码实现

import time
import pandas as pd
from pandas import Series,DataFrame
from sklearn.ensemble import BaggingClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import AdaBoostClassifier
from sklearn.model_selection import cross_val_score
from sklearn import preprocessing
from sklearn import datasets
iris = datasets.load_iris()
X,y = iris.data[:,1:3],iris.target
start = time.clock()  # 计时
min_max_scaler = preprocessing.MinMaxScaler()

# 读取训练数据 并数据规整化
raw_data  = pd.read_csv('train_data.csv') 
raw_datax = raw_data[:20000]
X1_scaled = min_max_scaler.fit_transform(raw_datax.ix[:,3:7])
y1 = raw_datax['Y1']
y1 = list(y1)

# 读取测试数据 并数据规整化
raw_datat  = pd.read_csv('test_data.csv')
raw_datatx = raw_datat[:10000]
X1t_scaled = min_max_scaler.fit_transform(raw_datatx.ix[:,3:7])
y1t = raw_datatx['Y1']
y1t = list(y1t)

print len(X1_scaled)
print len(X1t_scaled)
end = time.clock()
print '运行时间:',end - start
clf = DecisionTreeClassifier().fit(X1_scaled,y1)
clfb = BaggingClassifier(base_estimator= DecisionTreeClassifier()
                         ,max_samples=0.5,max_features=0.5).fit(X1_scaled,y1)

predict = clf.predict(X1t_scaled)
predictb = clfb.predict(X1t_scaled)

print clf.score(X1t_scaled,y1t)
print clfb.score(X1t_scaled,y1t)

# print Series(predict).value_counts()
# print Series(predictb).value_counts()

[图片上传失败...(image-790f8-1518054828425)]

方法总结

1.运用注意点
2.优化方向点

资料参考:http://blog.csdn.net/qq_30189255/article/details/51532442

上一篇 下一篇

猜你喜欢

热点阅读