Spark.GBDT学习-GBTClassifier

2018-07-03  本文已影响0人  松鼠胃口好

用于分类的GBT(Gradient-Boosted Trees)算法,基于J.H. Friedman. "Stochastic Gradient Boosting"实现,目前不支持多分类任务。Gradient Boosting vs. TreeBoost:

  • 本实现基于Stochastic Gradient Boosting(随机梯度提升),而不是TreeBoost
  • 两种方法都是通过最小化损失函数,学习树的集成
  • TreeBoost方法相对于原始方法,基于损失函数对叶节点的输出进行了额外的修改
  • Spark考虑未来实现TreeBoost

GBTClassifier

定义

一个唯一标识uid,继承了Predictor类,继承了GBTClassifierParamsDefaultParamsWritableLogging特质。其中Predictor中的三个元素分别代表: 特征类型、学习器、学习到用于预测的模型

class GBTClassifier(override val uid: String) 
extends Predictor[Vector, GBTClassifier, GBTClassificationModel] 
with GBTClassifierParams with DefaultParamsWritable with Logging 
{
    def this() = this(Identifiable.randomUID("gbtc"))
    ...
}

参数

为了兼容JAVA API,覆盖了继承自特质(with trait)的参数setter方法。

  1. TreeClassifierParams参数
  1. TreeEnsembleParams参数
  1. GBTParams参数
  1. GBTClassifierParams参数

方法

  1. copy方法
    GBTClassifier的拷贝函数。
  2. train方法
    GBTClassifier类的主要方法,用于训练得到学习好的用于预测的模型。
// @input: 训练数据, DataSet
// @output: 学习到的模型, GBTClassificationModel
override protected def train(dataset: Dataset[_]):
GBTClassificationModel = {
    // 得到类别特征
    val categoricalFeatures: Map[Int, Int] =
    MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
    // 转换训练数据并进行验证
    // 将DataSet转换成RDD[LabeledPoint]
    // 只支持二分类,要求label为0或1
    val oldDataset: RDD[LabeledPoint] =
        dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
            case Row(label: Double, features: Vector) =>
                require(label == 0 || label == 1, s"GBTClassifier was given dataset with invalid label $label.  Labels must be in {0,1}; note that GBTClassifier currently only supports binary classification.")
            LabeledPoint(label, features)
        }
    // 获得特征个数及boosting策略
    val numFeatures = oldDataset.first().features.size
    val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
    // 用于记录日志
    val instr = Instrumentation.create(this, oldDataset)
    instr.logParams(params: _*)
    instr.logNumFeatures(numFeatures)
    instr.logNumClasses(2)
    // 调用GradientBoostedTrees训练得到一组学习器及其权重
    val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, $(seed))
    // 将学到的模型封装成GBTClassificationModel并返回
    val m = new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures)
    instr.logSuccess(m)
    m
}

GBTClassifier对象

object GBTClassifier extends DefaultParamsReadable[GBTClassifier] {
    // final变量,访问支持的损失函数类型
    final val supportedLossTypes: Array[String] = GBTClassifierParams.supportedLossTypes
    // 从目录中加载GBTClassifier
    override def load(path: String): GBTClassifier = super.load(path)
}

GBTClassificationModel

用于分类的GBT模型,仅支持二分类,支持连续特征和类别特征。

定义

继承了PredictionModel类以及多个特质,其中PredictionModel的两个元素分别代表特征类型、学习到用于预测的模型

class GBTClassificationModel private[ml](
    override val uid: String,
    private val _trees: Array[DecisionTreeRegressionModel],
    private val _treeWeights: Array[Double],
    override val numFeatures: Int)
extends PredictionModel[Vector, GBTClassificationModel]
with GBTClassifierParams 
with TreeEnsembleModel[DecisionTreeRegressionModel]
with MLWritable with Serializable 
{
    // 检查_trees.nonEmpty
    // 检查_trees.length == _treeWeights.length
    val numTrees: Int = _trees.length
    ...
}

方法

  1. transformImpl方法
    首先将GBTClassificationModel进行广播,然后通过udf进行预测数据,udf中调用predict方法实现。
override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
    // 广播本类
    val bcastModel = dataset.sparkSession.sparkContext.broadcast(this)
    val predictUDF = udf { (features: Any) =>
        // udf通过本类的predict方法实现
        bcastModel.value.predict(features.asInstanceOf[Vector])
    }
    // 使用udf将特征数据转换成预测数据
    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
  }
  1. predict方法
    关键的预测方法,先得到每个基学习器的预测值,然后进行融合得到最终的预测结果,最后得到类别结果。可以看到这里得到的预测值不是概率而是类别0/1,因为label被转换成了-1/+1,所以这里通过prediction>0.0得到预测lebel。
override protected def predict(features: Vector): Double = {
    // 得到每棵树的预测结果
    val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
    // 乘以权重之后求和得到融合结果
    val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
    // 得到预测lebel
    if (prediction > 0.0) 1.0 else 0.0
  }
  1. copy方法
    GBTClassificationModel的拷贝方法。
  2. toOld方法
    将ml的模型转换成mllib中老的API,ml域的私有方法。
private[ml] def toOld: OldGBTModel = {
    new OldGBTModel(OldAlgo.Classification, _trees.map(_.toOld), _treeWeights)
}
  1. write方法
    调用GBTClassificationModel对象的方法保存本模型。
override def write: MLWriter = new GBTClassificationModel.GBTClassificationModelWriter(this)

GBTClassificationModel对象

  1. fromOld方法
    从老的API中转换出当前模型
  2. GBTClassificationModelReader
    私有类,其中的load方法用于从目录中读取模型
  3. GBTClassificationModelWriter
    私有类,其中的saveImpl方法用于将本模型保存
  4. read方法
    新建GBTClassificationModelReader
  5. load方法
上一篇 下一篇

猜你喜欢

热点阅读