Questions to guide your reading
1. What are the strengths of Newton's method?
2. Which regularization method does the L-BFGS implementation use?

Overview: this article gives a brief review of where the quasi-Newton method L-BFGS comes from, and then walks through its implementation in Spark mllib.
Mathematical principles of quasi-Newton methods
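(What follows is a minimal sketch of the standard derivation; the notation $s_k$, $y_k$, $m$ follows the usual optimization literature and is not taken from the mllib source.)

Newton's method minimizes a twice-differentiable objective $f$ by iterating

$$x_{k+1} = x_k - \left[\nabla^2 f(x_k)\right]^{-1} \nabla f(x_k),$$

which converges quadratically near a minimum because it exploits curvature, not just slope. Its drawback is that each step must form and invert the $n \times n$ Hessian, costing $O(n^2)$ storage and $O(n^3)$ work. Quasi-Newton methods keep the fast convergence but replace the inverse Hessian with an approximation $H_k$ built only from gradients, constrained by the secant condition

$$H_{k+1} y_k = s_k, \qquad s_k = x_{k+1} - x_k, \qquad y_k = \nabla f(x_{k+1}) - \nabla f(x_k).$$

BFGS is the standard rank-two update satisfying this condition:

$$H_{k+1} = \left(I - \rho_k s_k y_k^{\top}\right) H_k \left(I - \rho_k y_k s_k^{\top}\right) + \rho_k s_k s_k^{\top}, \qquad \rho_k = \frac{1}{y_k^{\top} s_k}.$$

L-BFGS (limited-memory BFGS) never materializes $H_k$ at all: it stores only the most recent $m$ pairs $(s_i, y_i)$, the numCorrections parameter in the code below, and reconstructs the search direction $H_k \nabla f(x_k)$ on the fly with the two-loop recursion, cutting memory from $O(n^2)$ to $O(mn)$. This is what makes it practical for the high-dimensional weight vectors mllib deals with.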
Implementation
The regularization method used in the L-BFGS implementation is SquaredL2Updater. The optimization itself is delegated to the BreezeLBFGS class from the breeze library (a member project of scalanlp); mllib defines the DiffFunction that BreezeLBFGS requires.
The source of the runLBFGS function is as follows:
def runLBFGS(
    data: RDD[(Double, Vector)],
    gradient: Gradient,
    updater: Updater,
    numCorrections: Int,
    convergenceTol: Double,
    maxNumIterations: Int,
    regParam: Double,
    initialWeights: Vector): (Vector, Array[Double]) = {

  val lossHistory = new ArrayBuffer[Double](maxNumIterations)

  val numExamples = data.count()

  // CostFun wraps the distributed loss/gradient computation as the
  // DiffFunction that breeze's L-BFGS expects (see below).
  val costFun =
    new CostFun(data, gradient, updater, regParam, numExamples)

  val lbfgs = new BreezeLBFGS[BDV[Double]](maxNumIterations, numCorrections, convergenceTol)

  // iterations returns a lazy iterator of optimizer states; each call to
  // next() triggers one L-BFGS step, and hence one pass over the data.
  val states =
    lbfgs.iterations(new CachedDiffFunction(costFun), initialWeights.toBreeze.toDenseVector)

  /**
   * NOTE: lossSum and loss is computed using the weights from the previous iteration
   * and regVal is the regularization value computed in the previous iteration as well.
   */
  var state = states.next()
  while (states.hasNext) {
    lossHistory.append(state.value)
    state = states.next()
  }
  lossHistory.append(state.value)
  val weights = Vectors.fromBreeze(state.x)

  logInfo("LBFGS.runLBFGS finished. Last 10 losses %s".format(
    lossHistory.takeRight(10).mkString(", ")))

  (weights, lossHistory.toArray)
}
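Before reading CostFun it may help to see how this entry point is typically driven. Below is a minimal sketch, assuming an existing SparkContext `sc`; LogisticGradient, SquaredL2Updater, and MLUtils.appendBias are real mllib classes and methods, while the file path and the parameter values are illustrative only.

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.optimization.{LBFGS, LogisticGradient, SquaredL2Updater}
import org.apache.spark.mllib.util.MLUtils

// Load labeled points and convert them to the (label, features) pairs
// that runLBFGS expects. "data/sample.txt" is a placeholder path.
val training = MLUtils.loadLibSVMFile(sc, "data/sample.txt")
  .map(lp => (lp.label, MLUtils.appendBias(lp.features)))
  .cache()

val numFeatures = training.first()._2.size

val (weightsWithIntercept, lossHistory) = LBFGS.runLBFGS(
  training,
  new LogisticGradient(),
  new SquaredL2Updater(),
  10,    // numCorrections: the history size m kept by L-BFGS
  1e-4,  // convergenceTol
  20,    // maxNumIterations
  0.1,   // regParam
  Vectors.dense(new Array[Double](numFeatures)))  // start from zero weights

println(s"final loss: ${lossHistory.last}")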
The CostFun class is the heart of the implementation:

private class CostFun(
    data: RDD[(Double, Vector)],
    gradient: Gradient,
    updater: Updater,
    regParam: Double,
    numExamples: Long) extends DiffFunction[BDV[Double]] {

  private var i = 0  // iteration counter

  override def calculate(weights: BDV[Double]) = {
    // Have a local copy to avoid the serialization of CostFun object which is not serializable.
    val localData = data
    val localGradient = gradient

    // Aggregate per-example gradients and losses across the cluster: seqOp folds
    // each example into a partition-local (grad, loss) accumulator, and combOp
    // merges the partition accumulators. Gradient.compute adds this example's
    // gradient into grad in place and returns its loss.
    val (gradientSum, lossSum) = localData.aggregate((BDV.zeros[Double](weights.size), 0.0))(
      seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
        val l = localGradient.compute(
          features, label, Vectors.fromBreeze(weights), Vectors.fromBreeze(grad))
        (grad, loss + l)
      },
      combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
        (grad1 += grad2, loss1 + loss2)
      })

    /**
     * regVal is sum of weight squares if it's L2 updater;
     * for other updater, the same logic is followed.
     */
    val regVal = updater.compute(
      Vectors.fromBreeze(weights),
      Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2

    val loss = lossSum / numExamples + regVal
    /**
     * It will return the gradient part of regularization using updater.
     *
     * Given the input parameters, the updater basically does the following,
     *
     * w' = w - thisIterStepSize * (gradient + regGradient(w))
     * Note that regGradient is function of w
     *
     * If we set gradient = 0, thisIterStepSize = 1, then
     *
     * regGradient(w) = w - w'
     *
     * TODO: We need to clean it up by separating the logic of regularization out
     * from updater to regularizer.
     */
    // The following gradientTotal is actually the regularization part of gradient.
    // Will add the gradientSum computed from the data with weights in the next step.
    val gradientTotal = weights - updater.compute(
      Vectors.fromBreeze(weights),
      Vectors.dense(new Array[Double](weights.size)), 1, 1, regParam)._1.toBreeze

    // gradientTotal = gradientSum / numExamples + gradientTotal
    axpy(1.0 / numExamples, gradientSum, gradientTotal)

    i += 1

    (loss, gradientTotal)
  }
}
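CostFun's calculate method is the only contract breeze needs: given a weight vector, return the pair (loss, gradient). As a standalone illustration of that contract (independent of Spark; the quadratic objective and all names here are invented for the example, while the LBFGS constructor arguments mirror the three parameters mllib passes to BreezeLBFGS above), breeze's optimizer can be driven directly:

import breeze.linalg.DenseVector
import breeze.optimize.{DiffFunction, LBFGS}

object BreezeLBFGSDemo {
  def main(args: Array[String]): Unit = {
    // A toy objective: f(x) = ||x - 3||^2, whose gradient is 2 * (x - 3).
    val f = new DiffFunction[DenseVector[Double]] {
      def calculate(x: DenseVector[Double]): (Double, DenseVector[Double]) = {
        val diff = x - 3.0
        (diff dot diff, diff * 2.0)
      }
    }

    // Arguments: maxIter, m (numCorrections), tolerance (convergenceTol).
    val lbfgs = new LBFGS[DenseVector[Double]](100, 10, 1e-6)
    val xOpt = lbfgs.minimize(f, DenseVector.zeros[Double](5))
    println(xOpt)  // approximately DenseVector(3.0, 3.0, 3.0, 3.0, 3.0)
  }
}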
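The regGradient(w) = w - w' trick documented inside calculate is easy to verify for the L2 case. The sketch below is not mllib code; it re-implements SquaredL2Updater's documented rule, w' = (1 - stepSize * regParam) * w - stepSize * gradient, on plain arrays to show what the two updater.compute calls in CostFun extract:

// Mirror of SquaredL2Updater's update rule (illustrative only, iter fixed at 1).
def l2Update(w: Array[Double], grad: Array[Double],
             stepSize: Double, regParam: Double): Array[Double] =
  w.zip(grad).map { case (wi, gi) => (1.0 - stepSize * regParam) * wi - stepSize * gi }

val w = Array(1.0, -2.0, 0.5)
val regParam = 0.1
val zeros = Array.fill(w.length)(0.0)

// Call 1 (stepSize = 0): weights are untouched, and the updater's second
// return value is the regularization term regVal = 0.5 * regParam * ||w||^2.
val regVal = 0.5 * regParam * w.map(x => x * x).sum

// Call 2 (gradient = 0, stepSize = 1): w' = (1 - regParam) * w, so
// regGradient(w) = w - w' = regParam * w, exactly the gradient of regVal.
val wPrime = l2Update(w, zeros, 1.0, regParam)
val regGradient = w.zip(wPrime).map { case (a, b) => a - b }
assert(regGradient.sameElements(w.map(_ * regParam)))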
This article is reposted from 徽沪一郎: http://www.cnblogs.com/hseagle/p/3927887.html