Spark.GBDT学习-GradientBoostedTree

2018-07-03  本文已影响0人  松鼠胃口好

run方法

根据任务类型训练得到一组弱学习器及对应的权重。分类任务(目前只能处理二分类)和回归任务调用的是相同的方法进行训练,分类任务可以看作是取值范围为[-1, +1]的回归任务。基学习器是DecisionTreeRegressionModel

// Method to train a gradient boosting model
// @input: 训练数据集, RDD[LabelPoint]
// @input: boosting策略, boostingStrategy
// @input: 随机数种子, seed
// @output: (array of decision tree models, array of model weights)
def run(
    input: RDD[LabeledPoint], 
    boostingStrategy: OldBoostingStrategy, 
    seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
    val algo = boostingStrategy.treeStrategy.algo
    // 根据boosting策略选择回归还是分类
    // 分类和回归调用的是同一个方法,唯一的区别就是需要将分类任务的label进行转换
    algo match {
        case OldAlgo.Regression =>
            GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false, seed)
        case OldAlgo.Classification =>
            // 分类任务, 需要先将label映射为-1, +1. 这样二分类就可以看作是[-1, +1]的回归问题
            val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
            GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false, seed)
        case _ =>
            throw new IllegalArgumentException(s"$algo is not supported by gradient boosting.") 
    }
}

runWithValidation方法

基于验证集的训练方法。同run方法唯一的区别就是增加了验证集。验证集要和训练集不同并且符合相同的分布(e.g.通过randomSplit方法得到的两个数据集)。验证集的功能就是:

// 和`run`方法的区别是调用`boost`方法时,参数validate的默认值为true,并且对于分类任务训练集和验证集都要进行label转换. 当validate为false时,验证集是无效的
GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true, seed)

computeInitialPredictionAndError方法

计算gradient boosting第一次迭代产生模型(学习到的第一棵树)预测值以及误差。

// @input: 训练数据集, RDD[LabeledPoint]
// @input: 第一棵树的学习率(权重), Double
// @input: 第一棵树, DecisionTreeRegressionModel
// @input: 评价标准(evaluation metric), Loss
// @output: RDD(Tuple2(prediction, error))
def computeInitialPredictionAndError(
    data: RDD[LabeledPoint],
    initTreeWeight: Double,
    initTree: DecisionTreeRegressionModel,
    loss: OldLoss): RDD[(Double, Double)] = {
    data.map { lp =>
        // 调用updatePrediction得到预测结果, 已有的预测值为0.0
        val pred = updatePrediction(lp.features, 0.0, initTree, initTreeWeight)
        // 调用Loss计算预测误差
        val error = loss.computeError(pred, lp.label)
        (pred, error)
    }
}

updatePrediction方法

将新一轮boosting迭代产生模型的预测值累加到之前的预测值上。

// @input: 特征,Vector
// @input: 已有的预测值(通过累加的方式集成多个弱学习的学习结果)
// @input: 新的决策树模型, DecisionTreeRegressionModel
// @input: 新模型的权重(学习率)
def updatePrediction(
    features: Vector,
    prediction: Double,
    tree: DecisionTreeRegressionModel,
    weight: Double): Double = {
    // 调用决策树的预测方法得到预测值,乘以学习率,累加到已有的预测值上
    prediction + tree.rootNode.predictImpl(features).prediction * weight
}

updatePredictionError方法

根据新一轮boosting迭代产生的模型调用updatePrediction方法得到新的预测值,并计算新的误差。

// @input: 训练数据, 新的决策树模型及权重, 评估方法
// @input: 上一轮的(预测值, 误差), 用于更新新的预测值
def updatePredictionError(
    data: RDD[LabeledPoint],
    predictionAndError: RDD[(Double, Double)],
    treeWeight: Double,
    tree: DecisionTreeRegressionModel,
    loss: OldLoss): RDD[(Double, Double)] = {
    val newPredError = data.zip(predictionAndError).mapPartitions {
        iter => iter.map {
            // zip之后, 形成(key, value), 第一个RDD的元素是key
            case (lp, (pred, error)) => 
                val newPred = updatePrediction(lp.features, pred, tree, treeWeight)
                val newError = loss.computeError(newPred, lp.label)
                (newPred, newError)
        }
    }
    newPredError
}

computeError方法

计算GBT的误差,该方法没有在算法中使用,但是对于Debug很有用。计算输入数据的平均误差。

// @input: 输入数据, 基学习数组, 权重数组, 评估方法
// @output: 平均误差
def computeError(
    data: RDD[LabeledPoint],
    trees: Array[DecisionTreeRegressionModel],
    treeWeights: Array[Double],
    loss: OldLoss): Double = {
    data.map { lp =>
        // 计算预测值foldLeft从左边遍历元素(model, weight), 初始值0, 得到累加的预测值
        val predicted = trees.zip(treeWeights).foldLeft(0.0) { case (acc, (model, weight)) =>
            updatePrediction(lp.features, acc, model, weight)
        }
        // 得到预测值之后计算误差
        loss.computeError(predicted, lp.label)
    }.mean()
}

evaluateEachIteration方法

gradient boosting的每一次迭代计算误差或损失,相当于evaluate方法。该方法好像没有被调用过。

// @input: 输入数据, 基学习数组, 权重数组, 评估方法, 算法类型(回归/分类)
def evaluateEachIteration(
    data: RDD[LabeledPoint],
    trees: Array[DecisionTreeRegressionModel],
    treeWeights: Array[Double],
    loss: OldLoss,
    algo: OldAlgo.Value): Array[Double] = {
    val sc = data.sparkContext
    // 对于二分类任务需要将label映射到-1/+1
    val remappedData = algo match {
        case OldAlgo.Classification => data.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
        case _ => data
    }
    // 广播trees, Kryo序列化可以注册该类
    val broadcastTrees = sc.broadcast(trees)
    val localTreeWeights = treeWeights
    val treesIndices = trees.indices
    val dataCount = remappedData.count()
    // 计算每一轮迭代的平均误差
    val evaluation = remappedData.map { point =>
        treesIndices.map { idx =>
            // 计算每一个基学习器的预测值
            val prediction = broadcastTrees.value(idx)
                .rootNode
                .predictImpl(point.features)
                .prediction
            prediction * localTreeWeights(idx)
        }
        // 累加得到每一轮的预测值
        .scanLeft(0.0)(_ + _).drop(1)
        // 计算得到每一轮的误差
        .map(prediction => loss.computeError(prediction, point.label))
    }
    // 计算所有数据每一轮的平均误差
    .aggregate(treesIndices.map(_ => 0.0))(
        (aggregated, row) => treesIndices.map(idx => aggregated(idx) + row(idx)),
        (a, b) => treesIndices.map(idx => a(idx) + b(idx)))
    .map(_ / dataCount)
    broadcastTrees.destroy()
    evaluation.toArray
}

该方法中的scanLeftaggregate方法需要额外说明。

// reduceLeft/Right没有初始值
// foldLeft/Right有初始值
// scanLeft/Right得到累积的中间结果的集合
val abc = List("A", "B", "C")
def add(res: String, x: String) = { 
  println(s"op: $res + $x = ${res + x}")
  res + x
}
abc.reduceLeft(add)
// op: A + B = AB
// op: AB + C = ABC
// res: String = ABC
abc.foldLeft("z")(add)
// op: z + A = zA
// op: zA + B = zAB
// op: zAB + C = zABC
// res: String = zABC
abc.scanLeft("z")(add)
// op: z + A = zA
// op: zA + B = zAB
// op: zAB + C = zABC
// res: List[String] = List(z, zA, zAB, zABC) 
def aggregate(zeroValue)(seqOp, combOp)
// zeroValue是初始值
// seqOp用于计算一个分区中的结果
// combOp用户合并不同分区的结果
data.aggregate(treesIndices.map(_ => 0.0))(
    (aggregated, row) => treesIndices.map(idx => aggregated(idx) + row(idx)),
    (a, b) => treesIndices.map(idx => a(idx) + b(idx))
)
// 初始值是0数组, aggregated是累积结果, row是data中的每一行数据。意思就是将data中每一行相同位置的数据进行累加。之后(a, b)是合并不同分区的结果

boost方法

最关键的方法,用户训练得到模型。


def boost(
    input: RDD[LabeledPoint],
    validationInput: RDD[labeledPoint],
    boostingStrategy: OldBoostingStrategy,
    validate: Boolean,
    seed: Long): 
(Array[DecisionTreeRegressionModel], Array[Double]) = {
    // 验证boosting策略是否有效(只支持二分类以及回归任务,要求学习率(0,1])
    boostingStrategy.assertValid()
    // 初始化gradient boosting参数(基学习器是回归树)
    val numIterations = boostingStrategy.numIterations
    val baseLearners = new Array[DecisionTreeRegressionModel](numIterations)
    val baseLearnerWeights = new Array[Double](numIterations)
    val loss = boostingStrategy.loss
    val learningRate = boostingStrategy.learningRate
    // 初始化基学习器参数,基学习器是基于方差不纯度(variance impurity)的回归树
    val treeStrategy = boostingStrategy.treeStrategy.copy
    val validationTol = boostingStrategy.validationTol
    treeStrategy.algo = OldAlgo.Regression
    treeStrategy.impurity = OldVariance
    // 基学习器策略要求分类任务类别>=2, 不纯度测量方法为Gini或Entropy; 回归任务要求不纯度测量方法为variance。maxDepth>=0, maxBins>=2, minInstancesPerNode>=1, maxMemoryInMB<=10240, subsamplingRate满足(0,1]
    treeStrategy.assertValid()
    // 缓存训练数据(true/false标志用于判断是否需要unpersist)
    val persistedInput = if (input.getStorageLevel == StorageLevel.NONE)
    {
        input.persist(StorageLevel.MEMORY_AND_DISK)
        true
    } else {
        false
    }
    // 为训练数据和验证数据准备周期性的检查, 每进行一次迭代都更新对应的预测值和误差, 训练完成后删除所有的检查
    // 学习第一棵树(第一棵树的权重为1.0)
    val firstTree = new DecisionTreeRegressor().setSeed(seed)
    val firstTreeModel = firstTree.train(input, treeStrategy)
    val firstTreeWeight = 1.0
    baseLearners(0) = firstTreeModel
    baseLearnerWeights(0) = firstTreeWeight
    // 计算训练数据、验证数据的预测值与误差, 更新检查点
    computeInitialPredictionAndError(...)
    // 初始化最优验证误差以及最优的位置
    var bestValidateError = 
        if (validate) 
            validatePredError.values.mean() 
        else 0.0
    var bestM = 1
    // 主循环, 循环迭代学习, 梯度提升
    var m = 1 // 当前迭代
    var doneLearning = false // 是否早停
    while (m < numIterations && !doneLearning) {
        // 基于伪残差(pseudo-residuals, 负梯度)更新数据
        val data = predError.zip(input).map { case ((pred, _), point) =>
            LabeledPoint(-loss.gradient(pred, point.label), point.features)
        }
        val dt = new DecisionTreeRegressor().setSeed(seed + m)
        val model = dt.train(data, treeStrategy)
        baseLearners(m) = model
        // 权重设置为学习率,这种方法对于除了平方误差之外的损失函数是不正确的。权重应该针对每一种损失函数进行优化,但是这种方法尽管不是最优的,但是是合理的
        baseLearnerWeights(m) = learningRate
        // 计算预测和误差,更新检查
        // 验证集,用于判断早停和寻找最优的模型
        if (validate) {
            // 计算验证集预测和误差,更新验证集的检查
            validatePredError = updatePredictionError(...)
            val currentValidateError = validatePredError.values.mean()
            // 早停条件(开始的最优验证误差较大)
            // 1. 减小的误差小于validationTol(基学习器的参数),或者
            // 2. 验证误差增加(差小于0了), 模型可能过拟合
            // 返回对应最优验证误差的模型
            if (bestValidateError - currentValidateError < validationTol * Math.max(currentValidateError, 0.01)) {
                doneLearning = true
            } else if (currentValidateError < bestValidateError) {
                bestValidateError = currentValidateError
                bestM = m + 1
            }
        }
        m += 1
    }
    // 删除所有检查点
    // unpersist数据
    if (persistedInput) input.unpersist()
    // 返回模型
    if (validate) {
        (baseLearners.slice(0, bestM), baseLearnerWeights.slice(0, bestM))
    } else {
        (baseLearners, baseLearnerWeights)
    }
}

TODO

上一篇 下一篇

猜你喜欢

热点阅读