A Learning Summary of Spark MLlib (Java Version)
This post describes how to use Spark MLlib to build a machine learning model and predict new data. The end-to-end flow is: load the data, convert it to a DataFrame, standardize and select features, train the model, and test new samples.
Loading the data
For loading and saving data, MLlib provides the MLUtils class, described in its docs as "Helper methods to load, save and pre-process data used in MLLib." This post uses the sample_libsvm_data.txt file that ships with Spark, which contains 100 samples with 692 features each. Each line is in LibSVM format: a label followed by sparse index:value feature pairs.
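For reference, lines in this format look like the following (the index:value pairs here are illustrative, not copied from the actual file):

```
0 128:51 129:159 130:253
1 159:124 160:253 161:255
```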
Loading the LibSVM file
```java
JavaRDD<LabeledPoint> lpdata = MLUtils.loadLibSVMFile(sc, this.libsvmFile).toJavaRDD();
```

The LabeledPoint type corresponds to the LibSVM file format; each record consists of:

label (double), features (Vector)
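As a minimal sketch (my addition, not from the original post), a LabeledPoint can also be built by hand, which makes the (label, sparse vector) layout explicit:

```java
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;

// A positive sample with 692 features, only two of which are non-zero
// (the same layout that loadLibSVMFile produces from index:value pairs).
LabeledPoint p = new LabeledPoint(1.0,
        Vectors.sparse(692, new int[]{127, 128}, new double[]{51.0, 159.0}));
System.out.println(p.label() + " -> " + p.features().size() + " features");
```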
Converting to the DataFrame type
```java
JavaRDD<Row> jrow = lpdata.map(new LabeledPointToRow());
StructType schema = new StructType(new StructField[]{
        new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
        new StructField("features", new VectorUDT(), false, Metadata.empty()),
});
SQLContext jsql = new SQLContext(sc);
DataFrame df = jsql.createDataFrame(jrow, schema);
```

DataFrame: a distributed collection of data organized into named columns. Conceptually it is like a table in a relational database or a data frame in Python or R, but with more optimization under the hood. A DataFrame can be constructed from structured data files, Hive tables, external databases, or existing RDDs.
SQLContext: the entry point for all Spark SQL functionality is the SQLContext class or one of its subclasses. Creating a basic SQLContext only requires a SparkContext.
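A quick sanity check on the conversion (my addition; printSchema and show are standard DataFrame methods):

```java
// The schema should show two columns: label (double) and features (vector)
df.printSchema();
df.show(5);
```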
Feature extraction
Feature standardization
```java
StandardScaler scaler = new StandardScaler()
        .setInputCol("features")
        .setOutputCol("normFeatures")
        .setWithStd(true);
StandardScalerModel scalerModel = scaler.fit(df);  // learn per-feature statistics on the training data
DataFrame scalerDF = scalerModel.transform(df);
scalerModel.save(this.scalerModelPath);            // save the fitted model, not the unfitted estimator
```

Saving the fitted StandardScalerModel (rather than the unfitted StandardScaler estimator) preserves the per-feature statistics learned from the training data, so the test set can later be standardized consistently.
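With setWithStd(true) and the default setWithMean(false), each feature column is simply divided by its standard deviation, which keeps sparse vectors sparse. A minimal plain-Java sketch of that per-column scaling on a toy array (my addition; I am assuming Spark's corrected, n-1 sample standard deviation here):

```java
// Scale one feature column by its sample standard deviation, mirroring
// what StandardScaler does with withStd=true and withMean=false.
double[] col = {1.0, 2.0, 3.0, 4.0};
double mean = 0.0;
for (double v : col) mean += v / col.length;
double var = 0.0;
for (double v : col) var += (v - mean) * (v - mean) / (col.length - 1);
double std = Math.sqrt(var);
for (int i = 0; i < col.length; i++) col[i] /= std;  // scaled values
```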
Feature selection with the chi-squared statistic

```java
ChiSqSelector selector = new ChiSqSelector()
        .setNumTopFeatures(500)
        .setFeaturesCol("normFeatures")
        .setLabelCol("label")
        .setOutputCol("selectedFeatures");
ChiSqSelectorModel chiModel = selector.fit(scalerDF);
DataFrame selectedDF = chiModel.transform(scalerDF).select("label", "selectedFeatures");
chiModel.save(this.featureSelectedModelPath);
```
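To see which of the original columns survived the chi-squared filter, the fitted model exposes their indices (a small check I added, assuming the spark.ml ChiSqSelectorModel API):

```java
// Indices into the input vector of the numTopFeatures columns that were kept
int[] kept = chiModel.selectedFeatures();
System.out.println("kept " + kept.length + " features; first index: " + kept[0]);
```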
Training the machine learning model (SVM as an example)

```java
// Convert back to the LabeledPoint type and build the training set
JavaRDD<Row> selectedrows = selectedDF.javaRDD();
JavaRDD<LabeledPoint> trainset = selectedrows.map(new RowToLabel());

// Train the SVM model and save it
int numIteration = 200;
SVMModel model = SVMWithSGD.train(trainset.rdd(), numIteration);
model.clearThreshold();  // make predict() return raw margin scores instead of 0/1 labels
model.save(sc, this.mlModelPath);

// Convert a LabeledPoint to a Row
static class LabeledPointToRow implements Function<LabeledPoint, Row> {
    public Row call(LabeledPoint p) throws Exception {
        double label = p.label();
        Vector vector = p.features();
        return RowFactory.create(label, vector);
    }
}

// Convert a Row back to a LabeledPoint
static class RowToLabel implements Function<Row, LabeledPoint> {
    public LabeledPoint call(Row r) throws Exception {
        Vector features = r.getAs(1);
        double label = r.getDouble(0);
        return new LabeledPoint(label, features);
    }
}
```
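The post trains on all 100 samples and evaluates on a separate LibSVM file. If you would rather hold out part of the labeled data instead, a random split is the usual pattern (a sketch of mine with arbitrary weights and seed, reusing trainset and numIteration from above):

```java
// Hold out 30% of the labeled data for evaluation
JavaRDD<LabeledPoint>[] splits = trainset.randomSplit(new double[]{0.7, 0.3}, 11L);
JavaRDD<LabeledPoint> training = splits[0].cache();
JavaRDD<LabeledPoint> heldout = splits[1];
SVMModel heldoutModel = SVMWithSGD.train(training.rdd(), numIteration);
```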
Testing new samples

Before testing new samples, the same data conversion and feature extraction steps must be applied to them. This is why the training phase saves not only the machine learning model itself but also the intermediate feature extraction models. The code is as follows:
```java
// Initialize Spark
SparkConf conf = new SparkConf().setAppName("SVM").setMaster("local");
conf.set("spark.testing.memory", "2147480000");
SparkContext sc = new SparkContext(conf);

// Load the test data
JavaRDD<LabeledPoint> testData = MLUtils.loadLibSVMFile(sc, this.predictDataPath).toJavaRDD();

// Convert to the DataFrame type
JavaRDD<Row> jrow = testData.map(new LabeledPointToRow());
StructType schema = new StructType(new StructField[]{
        new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
        new StructField("features", new VectorUDT(), false, Metadata.empty()),
});
SQLContext jsql = new SQLContext(sc);
DataFrame df = jsql.createDataFrame(jrow, schema);

// Standardize with the scaler model fitted on the training data
StandardScalerModel scalerModel = StandardScalerModel.load(this.scalerModelPath);
DataFrame scalerDF = scalerModel.transform(df);

// Feature selection
ChiSqSelectorModel chiModel = ChiSqSelectorModel.load(this.featureSelectedModelPath);
DataFrame selectedDF = chiModel.transform(scalerDF).select("label", "selectedFeatures");

// Convert back to LabeledPoint for prediction
JavaRDD<LabeledPoint> testset = selectedDF.javaRDD().map(new RowToLabel());
```
Scoring the test set

```java
SVMModel svmmodel = SVMModel.load(sc, this.mlModelPath);
JavaRDD<Tuple2<Double, Double>> predictResult = testset.map(new Prediction(svmmodel));
predictResult.collect();

// Map each sample to a (raw score, true label) pair
static class Prediction implements Function<LabeledPoint, Tuple2<Double, Double>> {
    SVMModel model;
    public Prediction(SVMModel model) {
        this.model = model;
    }
    public Tuple2<Double, Double> call(LabeledPoint p) throws Exception {
        Double score = model.predict(p.features());
        return new Tuple2<Double, Double>(score, p.label());
    }
}
```
Computing the accuracy

```java
double accuracy = predictResult.filter(new PredictAndScore()).count() * 1.0 / predictResult.count();
System.out.println(accuracy);

static class PredictAndScore implements Function<Tuple2<Double, Double>, Boolean> {
    public Boolean call(Tuple2<Double, Double> t) throws Exception {
        double score = t._1();
        double label = t._2();
        System.out.print("score:" + score + ", label:" + label);
        // Labels in the data are 0/1, and after clearThreshold() the model returns a
        // raw margin, so a non-negative score means the positive class (label 1.0).
        if (score >= 0.0 && label == 1.0) return true;
        else if (score < 0.0 && label == 0.0) return true;
        else return false;
    }
}
```

The complete code is on my GitHub: https://github.com/Quincy1994/MachineLearning/
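As an aside, because clearThreshold() makes predict() return raw margins, the same (score, label) pairs can also be evaluated without a fixed threshold using MLlib's built-in BinaryClassificationMetrics, for example to get the area under the ROC curve. A sketch of mine using the standard mllib evaluation API:

```java
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics;
import scala.Tuple2;

// predictResult holds the (raw score, true label) pairs from the step above
JavaRDD<Tuple2<Object, Object>> scoreAndLabels = predictResult.map(
        t -> new Tuple2<Object, Object>(t._1(), t._2()));
BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(scoreAndLabels.rdd());
System.out.println("Area under ROC = " + metrics.areaUnderROC());
```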
總結(jié)