spark mllib源码分析之DecisionTree与GBDT
我們?cè)谇懊娴奈恼轮v過(guò),在spark的實(shí)現(xiàn)中,樹模型的依賴鏈?zhǔn)荊BDT-> Decision Tree-> Random Forest,前面介紹了最基礎(chǔ)的Random Forest的實(shí)現(xiàn),在此基礎(chǔ)上我們介紹Decision Tree和GBDT的實(shí)現(xiàn)。
1. Decision Tree
1.1. DT的使用
官方給出的demo
// Train a DecisionTree model.
? ? // ?Empty categoricalFeaturesInfo indicates all features are continuous.
? ? val numClasses = 2
? ? val categoricalFeaturesInfo = Map[Int, Int]()
? ? val impurity = "gini"
? ? val maxDepth = 5
? ? val maxBins = 32
? ? val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
? ? ? impurity, maxDepth, maxBins)
其入?yún)⒊瞬恍枰付鋫€(gè)數(shù),其他參數(shù)與隨機(jī)森林類似,不再贅述
1.2 實(shí)現(xiàn)
主要的邏輯在DecisionTree.scala的run函數(shù)中
? /**
? ?* Method to train a decision tree model over an RDD
? ?* @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
? ?* @return DecisionTreeModel that can be used for prediction
? ?*/
? @Since("1.2.0")
? def run(input: RDD[LabeledPoint]): DecisionTreeModel = {
? ? // Note: random seed will not be used since numTrees = 1.
? ? val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0)
? ? val rfModel = rf.run(input)
? ? rfModel.trees(0)
? }
其實(shí)就是Random Forest 1棵樹的情形,同時(shí)特征不再抽樣。
2. Gradient Boosting Decision Tree
2.1. 算法簡(jiǎn)介
簡(jiǎn)稱GBDT,中文譯作梯度提升決策樹,估計(jì)沒(méi)幾個(gè)人聽(tīng)過(guò)。這里貼幾張之前介紹GBDT的PPT,簡(jiǎn)單回顧起算法原理,其中內(nèi)容來(lái)自wikipedia和”From RankNet to LambdaRank to LambdaMAR An Overview”這篇文章。
2.1.1. 算法原理
??
在這個(gè)算法里面,并沒(méi)有限定使用決策樹,如果使用決策樹,對(duì)應(yīng)里面的h應(yīng)該是樹結(jié)構(gòu),我們以決策樹說(shuō)明?
1. 使用原始樣本直接訓(xùn)練一棵樹?
循環(huán)訓(xùn)練?
2. 計(jì)算偽殘差,實(shí)際是梯度?
3. 將2中的偽殘差作為樣本的label去訓(xùn)練決策樹?
4. 這里是用最優(yōu)化方法計(jì)算葉子節(jié)點(diǎn)的輸出,而spark中直接使用的均值?
5. 計(jì)算當(dāng)輪模型的輸出,方法是上一輪的輸出加上本輪的預(yù)測(cè)值?
6. 循環(huán)結(jié)束后,輸出模型
2.1.2. 以二分類為例
?
?
?
2.2. GBDT使用
官方demo
// Train a GradientBoostedTrees model.
// The defaultParams for Classification use LogLoss by default.
val boostingStrategy = BoostingStrategy.defaultParams("Classification")
boostingStrategy.numIterations = 3 // Note: Use more iterations in practice.
boostingStrategy.treeStrategy.numClasses = 2
boostingStrategy.treeStrategy.maxDepth = 5
// Empty categoricalFeaturesInfo indicates all features are continuous.
boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]()
val model = GradientBoostedTrees.train(trainingData, boostingStrategy)
首先初始化訓(xùn)練參數(shù)boostingStrategy,然后設(shè)置其迭代次數(shù),分類樹,樹的最大深度,離散特征及其特征值數(shù),我們看下默認(rèn)的參數(shù)都有哪些
/**
? ?* Returns default configuration for the boosting algorithm
? ?* @param algo Learning goal. ?Supported:
? ?* ? ? ? ? ? ? [[org.apache.spark.mllib.tree.configuration.Algo.Classification]],
? ?* ? ? ? ? ? ? [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
? ?* @return Configuration for boosting algorithm
? ?*/
? @Since("1.3.0")
? def defaultParams(algo: Algo): BoostingStrategy = {
? ? val treeStrategy = Strategy.defaultStrategy(algo)
? ? treeStrategy.maxDepth = 3
? ? algo match {
? ? ? case Algo.Classification =>
? ? ? ? treeStrategy.numClasses = 2
? ? ? ? new BoostingStrategy(treeStrategy, LogLoss)
? ? ? case Algo.Regression =>
? ? ? ? new BoostingStrategy(treeStrategy, SquaredError)
? ? ? case _ =>
? ? ? ? throw new IllegalArgumentException(s"$algo is not supported by boosting.")
? ? }
? }
默認(rèn)樹的最大深度為3,如果是分類,為二分類,使用LogLoss;如果是回歸,使用SquareError,均方誤差。然后使用Strategy的默認(rèn)參數(shù)
? /**
? ?* Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]]
? ?* @param algo Algo.Classification or Algo.Regression
? ?*/
? @Since("1.3.0")
? def defaultStrategy(algo: Algo): Strategy = algo match {
? ? case Algo.Classification =>
? ? ? new Strategy(algo = Classification, impurity = Gini, maxDepth = 10,
? ? ? ? numClasses = 2)
? ? case Algo.Regression =>
? ? ? new Strategy(algo = Regression, impurity = Variance, maxDepth = 10,
? ? ? ? numClasses = 0)
? }
Strategy的默認(rèn)參數(shù)也比較簡(jiǎn)單,其意義參見(jiàn)之前的文章。
2.3. GBDT實(shí)現(xiàn)
其實(shí)現(xiàn)開始于GradientBoostedTrees.scala的run函數(shù)
? /**
? ?* Method to train a gradient boosting model
? ?* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
? ?* @return a gradient boosted trees model that can be used for prediction
? ?*/
? @Since("1.2.0")
? def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = {
? ? val algo = boostingStrategy.treeStrategy.algo
? ? algo match {
? ? ? case Regression =>
? ? ? ? GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false)
? ? ? case Classification =>
? ? ? ? // Map labels to -1, +1 so binary classification can be treated as regression.
? ? ? ? val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
? ? ? ? GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false)
? ? ? case _ =>
? ? ? ? throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
? ? }
? }
從其注釋可以看到,spark GBDT只實(shí)現(xiàn)了二分類,并且二分類的class必須是0/1,其把0/1轉(zhuǎn)化成-1/+1的label,然后按回歸處理。
2.3.2. 數(shù)據(jù)結(jié)構(gòu)
2.3.2.1. LogLoss
在第二頁(yè)P(yáng)PT中我們給出了loss,spark使用的loss是σ=1,log前增加了系數(shù)2的情況
L(y,FN)=2log(1+e?2yFN)
L(y,FN)=2log(1+e?2yFN)
對(duì)應(yīng)梯度變成
g=4y/(1+e2yFm?1(x))
g=4y/(1+e2yFm?1(x))
其中m-1指的是在第m次迭代中,使用的是m-1次的預(yù)測(cè)值。注意到我們的PPT的第四頁(yè)的γ,其實(shí)是葉子節(jié)點(diǎn)的預(yù)測(cè)值,是通過(guò)最優(yōu)化得到的,而spark這里使用的是Random Forest的代碼,其impurity選擇的是variance,因此預(yù)測(cè)值是均值。
? @Since("1.2.0")
? override def gradient(prediction: Double, label: Double): Double = {
? ? - 4.0 * label / (1.0 + math.exp(2.0 * label * prediction))
? }
? override private[mllib] def computeError(prediction: Double, label: Double): Double = {
? //loss
? ? val margin = 2.0 * label * prediction
? ? // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
? ? 2.0 * MLUtils.log1pExp(-margin)
? }
SquaredError比較簡(jiǎn)單,這里不再啰嗦了。
2.3.1. init
將傳入的參數(shù)轉(zhuǎn)成訓(xùn)練時(shí)的參數(shù),cache predError和validatePredError,并且按treeStrategy.getCheckpointInterval(default 10)建立checkpoint。這里代碼比較簡(jiǎn)單,不再贅述。
2.3.2. build the first tree
參照算法原理的第一步,訓(xùn)練了第一棵樹,并且將weight設(shè)為1,,然后計(jì)算錯(cuò)誤率。調(diào)用了computeInitialPredictionAndError函數(shù)
? /**
? ?* :: DeveloperApi ::
? ?* Compute the initial predictions and errors for a dataset for the first
? ?* iteration of gradient boosting.
? ?* @param data: training data.
? ?* @param initTreeWeight: learning rate assigned to the first tree.
? ?* @param initTree: first DecisionTreeModel.
? ?* @param loss: evaluation metric.
? ?* @return a RDD with each element being a zip of the prediction and error
? ?* ? ? ? ? corresponding to every sample.
? ?*/
? @Since("1.4.0")
? @DeveloperApi
? def computeInitialPredictionAndError(
? ? ? data: RDD[LabeledPoint],
? ? ? initTreeWeight: Double,
? ? ? initTree: DecisionTreeModel,
? ? ? loss: Loss): RDD[(Double, Double)] = {
? ? data.map { lp =>
? ? ? val pred = initTreeWeight * initTree.predict(lp.features)
? ? ? val error = loss.computeError(pred, lp.label)
? ? ? (pred, error)
? ? }
? }
其中預(yù)測(cè)值直接使用DT的predict來(lái)預(yù)測(cè),error使用loss的computeError函數(shù),我們上面有介紹。
2.3.3. 循環(huán)訓(xùn)練
2.3.3.1. 樣本處理
對(duì)應(yīng)算法的第2步,計(jì)算梯度,并且作為label更新樣本
val data = predError.zip(input).map { case ((pred, _), point) =>
? ? ? ? LabeledPoint(-loss.gradient(pred, point.label), point.features)
? ? ? }
2.3.3.2. 訓(xùn)練樹
對(duì)應(yīng)算法的第3和第4步,用第2步的樣本作為輸入,訓(xùn)練決策樹
val model = new DecisionTree(treeStrategy).run(data)
timer.stop(s"building tree $m")
// Update partial model
baseLearners(m) = model
// Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
// ? ? ? Technically, the weight should be optimized for the particular loss.
// ? ? ? However, the behavior should be reasonable, though not optimal.
baseLearnerWeights(m) = learningRate
2.3.3.3. 計(jì)算模型輸出
實(shí)際調(diào)用updatePredictionError函數(shù),入?yún)⑹窃嫉臉颖?#xff0c;上一輪的錯(cuò)誤率(實(shí)際包含上一輪的模型輸出),本來(lái)的決策樹,學(xué)習(xí)率和loss計(jì)算對(duì)象。
? /**
? ?* :: DeveloperApi ::
? ?* Update a zipped predictionError RDD
? ?* (as obtained with computeInitialPredictionAndError)
? ?* @param data: training data.
? ?* @param predictionAndError: predictionError RDD
? ?* @param treeWeight: Learning rate.
? ?* @param tree: Tree using which the prediction and error should be updated.
? ?* @param loss: evaluation metric.
? ?* @return a RDD with each element being a zip of the prediction and error
? ?* ? ? ? ? corresponding to each sample.
? ?*/
? @Since("1.4.0")
? @DeveloperApi
? def updatePredictionError(
? ? data: RDD[LabeledPoint],
? ? predictionAndError: RDD[(Double, Double)],
? ? treeWeight: Double,
? ? tree: DecisionTreeModel,
? ? loss: Loss): RDD[(Double, Double)] = {
? ? val newPredError = data.zip(predictionAndError).mapPartitions { iter =>
? ? ? iter.map { case (lp, (pred, error)) =>
? ? ? //計(jì)算本輪模型的預(yù)測(cè)值
? ? ? ? val newPred = pred + tree.predict(lp.features) * treeWeight
? ? ? ? //計(jì)算本輪誤差
? ? ? ? val newError = loss.computeError(newPred, lp.label)
? ? ? ? //newPred是累計(jì),包含至本輪的模型輸出
? ? ? ? (newPred, newError)
? ? ? }
? ? }
? ? newPredError
? }
代碼中使用到的函數(shù)我們之前都有介紹。
2.3.3.3. validation(early stop)
類似計(jì)算錯(cuò)誤率,只是樣本使用validationInput,看平均誤差是否減少,如果不能使誤差減小就結(jié)束訓(xùn)練,相當(dāng)于出現(xiàn)過(guò)擬合了;如果能,就繼續(xù)訓(xùn)練,并且記錄最好的模型的index。這里一次誤差變大就結(jié)束訓(xùn)練比較武斷,最好應(yīng)該有一定的閾值,避免單次訓(xùn)練的波動(dòng)。代碼比較簡(jiǎn)單,就不放了。
2.3.3.4. 訓(xùn)練收尾
訓(xùn)練完成后,根據(jù)記錄的最優(yōu)模型的index,構(gòu)造GradientBoostedTreesModel。
3.結(jié)語(yǔ)
從上面的分析可以看到,由于spark在Random Forest特征方面的限制,以及GBDT實(shí)現(xiàn)中直接使用均值作為葉子節(jié)點(diǎn)的輸出值,early stop等,spark在樹模型上的精度可能會(huì)差一點(diǎn),實(shí)際使用的話,最好與其他實(shí)現(xiàn)比較后決定是否使用。
---------------------?
原文:https://blog.csdn.net/snaillup/article/details/74207929?
總結(jié)
以上是生活随笔為你收集整理的spark mllib源码分析之DecisionTree与GBDT的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: spark mllib源码分析之随机森林
- 下一篇: Tensorflow从入门到精通之:Te