如何在Java应用里集成Spark MLlib训练好的模型做预测
前言
昨天媛媛說(shuō),你是不是很久沒(méi)寫(xiě)博客了。我說(shuō)上一篇1.26號(hào),昨天3.26號(hào),剛好兩個(gè)月,心中也略微有些愧疚。今天正好有個(gè)好朋友問(wèn),怎么在Java應(yīng)用里集成Spark MLlib訓(xùn)練好的模型。在StreamingPro里其實(shí)都有實(shí)際的使用例子,但是如果有一篇文章講述下,我覺(jué)得應(yīng)該能讓更多人獲得幫助
追本溯源
記得我之前吐槽過(guò)Spark MLlib的設(shè)計(jì),也是因?yàn)橐粋€(gè)朋友使用了spark MLlib的pipeline做訓(xùn)練,然后他把這個(gè)pipeline放到了spring boot里,結(jié)果做預(yù)測(cè)的時(shí)候奇慢無(wú)比,一條記錄inference需要30多秒。為什么會(huì)這么慢呢?原因是Spark MLlib 是以批處理為核心設(shè)計(jì)理念的。比如上面朋友遇到的坑是有一部分原因來(lái)源于word2vec的transform方法:
@Since("2.0.0")override def transform(dataset: Dataset[_]): DataFrame = {transformSchema(dataset.schema, logging = true)val vectors = wordVectors.getVectors.mapValues(vv => Vectors.dense(vv.map(_.toDouble))).map(identity) // mapValues doesn't return a serializable map (SI-7005)val bVectors = dataset.sparkSession.sparkContext.broadcast(vectors)val d = $(vectorSize)來(lái)一條數(shù)據(jù)(通常API應(yīng)用都是如此),他需要先獲得vectors(詞到vector的映射)對(duì)象,假設(shè)你有十萬(wàn)個(gè)詞,
def getVectors: Map[String, Array[Float]] = {wordIndex.map { case (word, ind) =>(word, wordVectors.slice(vectorSize * ind, vectorSize * ind + vectorSize))}}每次請(qǐng)求他都要做如上調(diào)用和計(jì)算。接著還需要把這些東西(這個(gè)可能就比較大了,幾百M(fèi)或者幾個(gè)G都有可能)廣播出去。
所以注定快不了。
把model集成到Java 服務(wù)里實(shí)例
假設(shè)你使用貝葉斯訓(xùn)練了一個(gè)模型,你需要保存下這個(gè)模型,保存的方式如下:
val nb = new NaiveBayes() //做些參數(shù)配置和訓(xùn)練過(guò)程 ..... //保存模型 nb.write.overwrite().save(path + "/" + modelIndex)接著,在你的Java/scala程序里,引入spark core,spark mllib等包。加載模型:
val model = NaiveBayesModel.load(tempPath)這個(gè)時(shí)候因?yàn)橐鲱A(yù)測(cè),我們?yōu)榱诵阅?#xff0c;不能直接調(diào)用model的transform方法,你仔細(xì)觀察發(fā)現(xiàn),我們需要通過(guò)反射調(diào)用兩個(gè)方法,就能實(shí)現(xiàn)分類(lèi)。第一個(gè)是predictRaw方法,該方法輸入一個(gè)向量,輸出也為一個(gè)向量。我們其實(shí)不需要向量,我們需要的是一個(gè)分類(lèi)的id。predictRaw 方法在model里,但是沒(méi)辦法直接調(diào)用,因?yàn)槭撬接械?#xff1a;
override protected def predictRaw(features: Vector): Vector = {$(modelType) match {case Multinomial =>multinomialCalculation(features)case Bernoulli =>bernoulliCalculation(features)case _ =>// This should never happen.throw new UnknownError(s"Invalid modelType: ${$(modelType)}.")}}所以我們需要通過(guò)反射來(lái)完成:
val predictRaw = model.getClass.getMethod("predictRaw", classOf[Vector]).invoke(model, vec).asInstanceOf[Vector]現(xiàn)在我們已經(jīng)得到了predctRaw的結(jié)果,接著我們要用raw2probability 把向量轉(zhuǎn)化為一個(gè)概率分布,因?yàn)閟park 版本不同,該方法的簽名也略有變化,所以可能要做下版本適配:
val raw2probabilityMethod = if (sparkSession.version.startsWith("2.3")) "raw2probabilityInPlace" else "raw2probability" val raw2probability = model.getClass.getMethod(raw2probabilityMethod, classOf[Vector]).invoke(model, predictRaw).asInstanceOf[Vector]raw2probability 其實(shí)也還是一個(gè)向量,這個(gè)向量的長(zhǎng)度是分類(lèi)的數(shù)目,每個(gè)位置的值是概率。所以所以我們只要拿到最大的那個(gè)概率值所在的位置就行:
val categoryId = raw2probability.argmax這個(gè)時(shí)候categoryId 就是我們預(yù)測(cè)的分類(lèi)了。
截止到目前我們已經(jīng)完成了作為一個(gè)普通java/scala 方法的調(diào)用流程。如果我不想用在應(yīng)用程序里,而是放到spark 流式計(jì)算里呢?或者批處理也行,那么這個(gè)時(shí)候你只需要封裝一個(gè)UDF函數(shù)即可:
val models = sparkSession.sparkContext.broadcast(_model.asInstanceOf[ArrayBuffer[NaiveBayesModel]]) val f2 = (vec: Vector) => {models.value.map { model =>val predictRaw = model.getClass.getMethod("predictRaw", classOf[Vector]).invoke(model, vec).asInstanceOf[Vector]val raw2probability = model.getClass.getMethod(raw2probabilityMethod, classOf[Vector]).invoke(model, predictRaw).asInstanceOf[Vector]//model.getClass.getMethod("probability2prediction", classOf[Vector]).invoke(model, raw2probability).asInstanceOf[Vector]raw2probability}}sparkSession.udf.register(name , f2)上面的例子可以參考StreamingPro 中streaming.dsl.mmlib.algs.SQLNaiveBayes的代碼。不同的算法因?yàn)閮?nèi)部實(shí)現(xiàn)不同,我們使用起來(lái)也會(huì)略微有些區(qū)別。
總結(jié)
Spark MLlib學(xué)習(xí)了SKLearn里的transform和fit的概念,但是因?yàn)樵O(shè)計(jì)上還是遵循批處理的方式,實(shí)際部署后會(huì)有很大的性能瓶頸,不適合那種數(shù)據(jù)一條一條過(guò)來(lái)需要快速響應(yīng)的預(yù)測(cè)流程,所以需要調(diào)用一些內(nèi)部的API來(lái)完成最后的預(yù)測(cè)。
作者:祝威廉
鏈接:https://www.jianshu.com/p/3c038027ff61
來(lái)源:簡(jiǎn)書(shū)
簡(jiǎn)書(shū)著作權(quán)歸作者所有,任何形式的轉(zhuǎn)載都請(qǐng)聯(lián)系作者獲得授權(quán)并注明出處。
總結(jié)
以上是生活随笔為你收集整理的如何在Java应用里集成Spark MLlib训练好的模型做预测的全部?jī)?nèi)容,希望文章能夠幫你解決所遇到的問(wèn)題。
- 上一篇: Spark函数详解系列--RDD基本转换
- 下一篇: 关于spark的mllib学习总结(Ja