Java机器学习库ML之九交叉验证法(Cross Validation)
交叉驗證(Cross Validation,CV)是用來驗證分類器的性能一種統計分析方法,基本思想是把在某種意義下將原始數據(dataset)進行分組,一部分做為訓練集(train set),另一部分做為驗證集(validation set)。首先用訓練集對分類器進行訓練,再利用驗證集來測試訓練得到的模型(model),以此來做為評價分類器的性能指標。常見CV的方法如下:
1)Hold-Out Method
將原始數據隨機分為兩組,一組做為訓練集,一組做為驗證集,利用訓練集訓練分類器,然后利用驗證集驗證模型,記錄最后的分類準確率為此Hold-OutMethod下分類器的性能指標.此種方法的好處的處理簡單,只需隨機把原始數據分為兩組即可,其實嚴格意義來說Hold-Out Method并不能算是CV,因為這種方法沒有達到交叉的思想,由于是隨機的將原始數據分組,所以最后驗證集分類準確率的高低與原始數據的分組有很大的關系,所以這種方法得到的結果其實并不具有說服性.
2)K-fold Cross Validation(記為K-CV)
將原始數據分成K組(一般是均分),將每個子集數據分別做一次驗證集,其余的K-1組子集數據作為訓練集,這樣會得到K個模型,用這K個模型最終的驗證集的分類準確率的平均數作為此K-CV下分類器的性能指標.K一般大于等于2,實際操作時一般從3開始取,只有在原始數據集合數據量小的時候才會嘗試取2.K-CV可以有效的避免過學習以及欠學習狀態的發生,最后得到的結果也比較具有說服性.
?
3).Leave-One-Out Cross Validation(記為LOO-CV)
如果設原始數據有N個樣本,那么LOO-CV就是N-CV,即每個樣本單獨作為驗證集,其余的N-1個樣本作為訓練集,所以LOO-CV會得到N個模型,用這N個模型最終的驗證集的分類準確率的平均數作為此下LOO-CV分類器的性能指標.相比于前面的K-CV,LOO-CV有兩個明顯的優點:
①a.每一回合中幾乎所有的樣本皆用于訓練模型,因此最接近原始樣本的分布,這樣評估所得的結果比較可靠。
②b.實驗過程中沒有隨機因素會影響實驗數據,確保實驗過程是可以被復制的。
但LOO-CV的缺點則是計算成本高,因為需要建立的模型數量與原始數據樣本數量相同,當原始數據樣本數量相當多時,LOO-CV在實作上便有困難幾乎就是不顯示,除非每次訓練分類器得到模型的速度很快,或是可以用并行化計算減少計算所需的時間.
ML庫中的交叉驗證法,應該是采用K折交叉驗證法,代碼示例如下:
/*** This file is part of the Java Machine Learning Library* * The Java Machine Learning Library is free software; you can redistribute it and/or modify* it under the terms of the GNU General Public License as published by* the Free Software Foundation; either version 2 of the License, or* (at your option) any later version.* * The Java Machine Learning Library is distributed in the hope that it will be useful,* but WITHOUT ANY WARRANTY; without even the implied warranty of* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the* GNU General Public License for more details.* * You should have received a copy of the GNU General Public License* along with the Java Machine Learning Library; if not, write to the Free Software* Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA* * Copyright (c) 2006-2012, Thomas Abeel* * Project: http://java-ml.sourceforge.net/* */ package com.gddx;import java.io.File; import java.util.Map;import net.sf.javaml.classification.Classifier; import net.sf.javaml.classification.KNearestNeighbors; import net.sf.javaml.classification.evaluation.CrossValidation; import net.sf.javaml.classification.evaluation.PerformanceMeasure; import net.sf.javaml.core.Dataset; import net.sf.javaml.tools.data.FileHandler;/*** This tutorial shows how you can do cross-validation with Java-ML* * @author Thomas Abeel* */ public class TutorialCrossValidation {/*** Default cross-validation with little options.*/public static void main(String[] args) throws Exception {/* Load data */Dataset data = FileHandler.loadDataset(new File("D:\\tmp\\javaml-0.1.7-src\\UCI-small\\iris\\iris.data"), 4, ",");//Dataset data = FileHandler.loadDataset(new File("D:\\tmp\\train_features_npequal.txt"), 12, "\\s+");/* Construct KNN classifier */Classifier knn = new KNearestNeighbors(2);/* Construct new cross validation instance with the KNN classifier */CrossValidation cv = new CrossValidation(knn);/* Perform 5-fold cross-validation on the data set */Map<Object, PerformanceMeasure> p = cv.crossValidation(data);System.out.println("Accuracy=" + p.get("Iris-virginica").getAccuracy());System.out.println(p);} }
總結
以上是生活随笔為你收集整理的Java机器学习库ML之九交叉验证法(Cross Validation)的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: Java机器学习库ML之八关于模型迭代训
- 下一篇: Java机器学习库ML之十一线性SVM