XGBoost_源码初探

2018-09-09  本文已影响833人  xieyan0811

1. 说明

 本篇来读读Xgboost源码。其核心代码基本在src目录下,由C++实现,40几个cc文件,代码11000多行,虽然不算太多,但想把核心代码都读明白,也需要很长时间。 我觉得阅读的目的主要是:了解基本原理,流程,核心代码的位置,修改从哪儿入手,而得以快速入门。因此,需要跟踪代码执行过程,同时查看在某一步骤其内部环境的取值情况。具体方法是:单步调试或在代码中加入一些打印信息,因此选择了安装编译源码的方式。

2. 下载编译

 用参数--recursive可以下载它的支持包rabit和cur,否则编不过

$ git clone --recursive https://github.com/dmlc/xgboost
$ cd xgboost
$ make -j4

3. 运行

 测试程序demo目录中有多分类,二分类,回归等各种示例,这里从二分类入手。

$ cd demo 
#运行一个测试程序 
$ cd binary_classification
$ ./runexp.sh # 可以通过修改cfg文件,增加迭代次数等,进一步调试

4. 主流程

 下面从main()开始,看看程序执行的主要流程,下图是一个示意图,每个黄色框对应一个cc文件,可以将它视作调用关系图,并非完全按照类图绘制,同时省略了一些主流程以外的细节,请各位以领会精神为主。

1) Src/cli_main.cc:(主程序入口)

i. CLIRunTask():解析参数,提供三个主要功能:训练,打印模型,预测.

ii. CLITrain():训练部分,装载数据后,主要调用学习器Learner实际功能(配置cofigure,迭代,评估,存储……),其中的for循环包含迭代调用计算和评估。

2) Src/learner.cc:(学习器)

 定义三个核心句柄gbm_(子模型tree/linear),obj_(损失函数),metrics_(评价函数)

i. UpdateOneIter():此函数会在每次迭代时被调用,主要包含四个步骤:调整参数(LazyInitDMatrix()),用当前模型预测(PredictRaw(),gbm_-> PredictBatch()),求当前预测结果和实际值的差异的方向(obj_->GetGradient()),根据差异修改模型(gbm_->DoBoost()),后面逐一细化。

ii. EvalOneIter() 支持对多个评价数据集分别评价,对每个数据集,先进行预测(PredictRaw()),评价(obj_->EvalTransform()),再调metrics_中的各个评价器,输出结果。

3) Src/metric/metric.cc(评价函数入口)

 基本上,每个目录都有一个入口函数,metric.cc是评价函数的入口,learn允许同时支持多个评价函数(注意评价函数和误差函数不同)。主要三种评价函数:多分类,排序,元素评价,分别定义在三个文件之中。

4) Src/objective/objective.cc(损失函数入口)

 objective.cc是损失函数的入口,Learner::load()函数调用Create()创建误失函数,该目录中实现了:多分类,回归,排序的多种损失函数(每个对应一个文件),每个损函数最核心的功能是GetGradient(),另外也可以参考plugin中示例,自定义损失函数。 例如:src/objective/regression_obj.cc(最常用的损失函数RegLossObj())计算一阶导,二阶导,存入gpair结构。这里加入了样本的权重,scale_pos_weight也是在此处起作用。

5) src/gbm/gbm.cc(迭代器Gradient Booster)

 这里是对模型的封装,主要支持tree和linear两种方式,树分类器又包含GBTree和Dart两种,Dart主要加入了归一化和dropout防过拟合,详见参考部分。 gbm.cc中也有三个重要句柄:model_存储当前模型数据,updaters_管理每一次迭代的更新算法, predictor_用于预测

i. DoBoost()和BoostNewTrees() 进一步迭代生成新树,详建更新器部分

ii. Predict*() 调用各种预测,详见预测部分

6) src/predictor/predictor.cc(预测工具入口)

 predictor.cc也是一个入口,可调用cpu和gpu两种预测方式。

i. PredValue():核心函数,计算了从训练到当前迭代的所有回归树集合(以回归树为例)。

7) src/tree/tree_updater.cc(树模型的具体实现)

 src/tree和src/linear分别是树和线性模型的具体实现,tree_updater是updater的入口,每一个Updater是对一棵树进行一次更新。其中的Updater分为两类:计算类和辅助类,updater都继承于TreeUpdater,互相之间又有调有关系,比如:prune调用sync,colmaker和fast_hist调用prune。

 以下为辅助类:

i. Src/tree/updater_prune.cc 用于剪枝

ii. Src/tree/updater_refresh.cc 用于更新权重和统计值

iii. Src/tree/updater_sync.cc 用于在分布式系统的节点间同步数据

iv. Src/tree/split_evaluator.cc 定义了两种切分方法:弹性网络elastic net, 单调约束monotonic,在此为切分评分,正则项在此发挥作用。打发的依据是差值,权重和正则化项。

 以下为算法类(基本都在xgboost论文第三章描述)
 对于树算法,最核心的是如何选择特征和特征的切分点,具体原理请见CART,算法,信息增益,熵等概念,这里实现的是几种树的生成方法。

v. Src/tree/updater_colmaker.cc 贪婪搜索算法(Exact Greedy Algorithm),最基本的树算法,一般都用它举例说明,这里提供了分布和非分布两种支持。在每个特征中选择该特征下的每个值作为其分裂点,计算增益损失。由内至外,关键函数分别是: EnumerateSplit() 穷举每一个枚举值,用split_evaluator打分。 ParallelFindSplit() 多线程,其它同上 UpdateSolution() 调上面两个split(),更新候选方案 FindSplit() 在当前层寻找最佳切分点,对比各个候选方案,方案来自上面的UpdateSolution()

vi. Src/tree/updater_histmaker.cc 它是xgboost默认的树生成算法, 它和后面提到的skmaker都继承自BaseMaker(BaseMaker的父类是TreeUpdate)是基于直方图选择特征切分点。 HistMaker提取Local和Global两种方式,Global是学习每棵树前, 提出候选切分点;Local是每次分裂前,重新提出候选切分点。 UpdateHistCol() 对每一个col,做直方图分箱,返回一个分界Entry列表。

vii. Src/tree/updater_skmaker.cc 继承自BaseMaker(BaseMaker父类TreeUpdate)加权分位数草图,用子集替代全集,使用近似的 sketch 方法寻找最佳分裂点。

5. 其它

1) GPU,多线程,分布式

 代码中也有大量操作GPU,多线程,分布式的操作,这里主要介绍核心流程,就没有提及,详见代码,其中.cu和.cuh是主要针对GPU的程序。

2) 关键字说明

 CSR:csr_matrix一种存储格式
 Dmlc(Deep Machine Learning in Common):分布式深度机器学习开源项目
 Rabit:可容错的allrecude(分布式),支持python和C++,可以运行在包括MPI和Hadoop 等各种平台上面
 Objective与Metric(Eval):这里的Metric和Eval都指评价函数,Objective指损失函数,它们计算的都是实际值和预测值之间的差异,只是用途不同,Objective主要在生成树时使用,用于计算误差和通过误差的方向调整树;而评价函数主要用于判断模型对数据的拟合程度,有时通过它判断何时停止迭代。

3) 基于直方图的切分点选择

 分位数quantiles:即把概率分布划分为连续的区间,每个区间的概率相同。把数值进行排序,然后根据你采用的几分位数把数据分为几份即可。
 xgboost用二阶导h对分位数进行加权,让相邻两个候选分裂点相差不超过某个值ε。因此,总共会得到1/ε个切分点。
 通过特征的分布,按照加权直方图算法确定一组候选分裂点,通过遍历所有的候选分裂点来找到最佳分裂点。它不会枚举所有的特征值,而是对特征值进行聚合统计,然后形成若干个bucket(桶),只将bucket边界上的特征值作为split point的候选,从而获得性能提升,对稀疏数据效果好。

6. 参考

1) XGBoost Documentation

https://xgboost.readthedocs.io/en/latest/

2) xgboost入门与实战(原理篇)

https://blog.csdn.net/sb19931201/article/details/52557382

3) XGBoost解析系列--源码主流程

https://blog.csdn.net/matrix_zzl/article/details/78699605

4) XGBoost 论文翻译+个人注释

https://blog.csdn.net/qdbszsj/article/details/79615712

5) DART booster

https://blog.csdn.net/Yongchun_Zhu/article/details/78745529

6) 『我爱机器学习』集成学习(三)XGBoost

https://www.hrwhisper.me/machine-learning-xgboost/

上一篇下一篇

猜你喜欢

热点阅读