Cross-validation experiments with Weka
Generating cross-validation folds (Java approach)
Reference:
http://weka.wikispaces.com/Generating+cross-validation+folds+%28Java+approach%29
This article describes how to generate train/test splits for cross-validation using the Weka API directly.
The following variables are given:
 Instances data = ...;   // contains the full dataset we want to create train/test sets from
 int seed = ...;         // the seed for randomizing the data
 int folds = ...;        // the number of folds to generate, >= 2
Randomize the data
First, randomize your data:
 Random rand = new Random(seed);            // create seeded number generator
 Instances randData = new Instances(data);  // create copy of original data
 randData.randomize(rand);                  // randomize data with number generator
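The point of this step is that the original data is left untouched and the shuffle can be reproduced from the seed. The following is a minimal plain-Java sketch of that idea that does not use Weka. The class and method names (`RandomizeSketch`, `randomizedCopy`) are hypothetical. `Collections.shuffle(list, rand)` stands in for `Instances.randomize(rand)`, so the order it produces is not guaranteed to match Weka's.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical plain-Java analogue of:
//   Instances randData = new Instances(data); randData.randomize(new Random(seed));
class RandomizeSketch {

    // Returns a shuffled copy of data; the input list itself is not modified.
    static <T> List<T> randomizedCopy(List<T> data, long seed) {
        Random rand = new Random(seed);           // seeded generator => reproducible order
        List<T> randData = new ArrayList<>(data); // copy of the original data
        Collections.shuffle(randData, rand);      // stand-in for Instances.randomize(rand)
        return randData;
    }

    public static void main(String[] args) {
        List<Integer> data = List.of(0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
        System.out.println(randomizedCopy(data, 1));
        System.out.println(randomizedCopy(data, 1)); // same seed, same order as the line above
        System.out.println(data);                    // original order is unchanged
    }
}
```

Because the generator is seeded, rerunning the program gives the same order. This is what makes a cross-validation run repeatable.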
If your data has a nominal class and you want to perform stratified cross-validation:
 randData.stratify(folds);
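Stratification reorders the randomized data so that each fold ends up with roughly the class proportions of the whole dataset. The following is a simplified plain-Java sketch of that idea. It is not Weka's source code, and `StratifySketch`/`stratify` are hypothetical names.

The sketch works in two steps. First it groups the instances by class, keeping their relative order. Then it lays them out by taking every `folds`-th element. It assumes the folds are afterwards cut as contiguous blocks, as `trainCV`/`testCV` do.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Hypothetical, simplified sketch of stratification (not Weka's implementation).
class StratifySketch {

    static <T> List<T> stratify(List<T> data, Function<T, String> classOf, int folds) {
        // Step 1: group by class label, labels in first-seen order, items in original order
        Map<String, List<T>> byClass = new LinkedHashMap<>();
        for (T item : data) {
            byClass.computeIfAbsent(classOf.apply(item), k -> new ArrayList<>()).add(item);
        }
        List<T> grouped = new ArrayList<>();
        for (List<T> group : byClass.values()) {
            grouped.addAll(group);
        }
        // Step 2: take positions start, start + folds, start + 2*folds, ... for each start,
        // so every contiguous block of the result samples all class groups evenly
        List<T> result = new ArrayList<>(grouped.size());
        for (int start = 0; start < folds; start++) {
            for (int j = start; j < grouped.size(); j += folds) {
                result.add(grouped.get(j));
            }
        }
        return result;
    }

    public static void main(String[] args) {
        // 6 instances of class A and 4 of class B, in some randomized order
        List<String> randData = List.of("A0", "B0", "A1", "A2", "B1", "A3", "A4", "B2", "A5", "B3");
        List<String> strat = stratify(randData, s -> s.substring(0, 1), 2);
        System.out.println(strat.subList(0, 5));  // fold 0: [A0, A2, A4, B0, B2]
        System.out.println(strat.subList(5, 10)); // fold 1: [A1, A3, A5, B1, B3]
    }
}
```

Each of the two folds gets 3 A and 2 B, which matches the 60/40 split of the full data. A class attribute that is numeric has no such groups, which is why stratification only applies to a nominal class.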
Generate the folds
Single run
Next, create the training and test set for each fold:
 for (int n = 0; n < folds; n++) {
   Instances train = randData.trainCV(folds, n);
   Instances test = randData.testCV(folds, n);

   // further processing, classification, etc.
   ...
 }
Note:
- the above code is used by the weka.filters.supervised.instance.StratifiedRemoveFolds filter
- the weka.classifiers.Evaluation class and the Explorer/Experimenter instead obtain the training set with the following variant, which additionally shuffles it with the given random number generator:
 Instances train = randData.trainCV(folds, n, rand);
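To make the fold logic concrete, here is a plain-Java sketch of contiguous cross-validation folds. It does not use Weka, and `FoldSketch` and its methods are hypothetical names. It assumes the following about how `testCV`/`trainCV` split the data:

- Fold `n` is one contiguous block of the randomized data.
- The first `size % folds` folds each get one extra element.
- The training set is everything outside that block.

The three-argument variant keeps exactly the same training instances and only shuffles their order.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical plain-Java sketch of contiguous CV folds (not Weka's source code).
class FoldSketch {

    // Number of elements in test fold n
    static int foldSize(int size, int folds, int n) {
        return size / folds + (n < size % folds ? 1 : 0);
    }

    // Position of the first element of test fold n (earlier folds may be one larger)
    static int firstIndex(int size, int folds, int n) {
        return n * (size / folds) + Math.min(n, size % folds);
    }

    // Test set: the contiguous block belonging to fold n
    static <T> List<T> testCV(List<T> data, int folds, int n) {
        int first = firstIndex(data.size(), folds, n);
        return new ArrayList<>(data.subList(first, first + foldSize(data.size(), folds, n)));
    }

    // Training set: everything before and after the test block, in the original order
    static <T> List<T> trainCV(List<T> data, int folds, int n) {
        int first = firstIndex(data.size(), folds, n);
        int end = first + foldSize(data.size(), folds, n);
        List<T> train = new ArrayList<>(data.subList(0, first));
        train.addAll(data.subList(end, data.size()));
        return train;
    }

    // Same members as trainCV(data, folds, n); only the order is shuffled with rand
    static <T> List<T> trainCV(List<T> data, int folds, int n, Random rand) {
        List<T> train = trainCV(data, folds, n);
        Collections.shuffle(train, rand);
        return train;
    }

    public static void main(String[] args) {
        List<Integer> randData = List.of(7, 2, 9, 0, 4, 1, 8, 3, 6, 5); // assume already randomized
        int folds = 3;
        Random rand = new Random(1);
        for (int n = 0; n < folds; n++) {
            System.out.println("fold " + n
                    + "  test=" + testCV(randData, folds, n)
                    + "  train=" + trainCV(randData, folds, n)
                    + "  shuffled train=" + trainCV(randData, folds, n, rand));
        }
    }
}
```

With 10 elements and 3 folds, the test folds have 4, 3 and 3 elements. Each element appears in exactly one test fold. Shuffling the training set does not change which instances a model sees. It can only matter for learners whose result depends on the order of the training data.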
Multiple runs
The example above performs only a single run of cross-validation. To perform, for example, 10 runs of 10-fold cross-validation, use the following loop:
 Instances data = ...;  // our dataset again, obtained from somewhere
 int runs = 10;
 for (int i = 0; i < runs; i++) {
   seed = i + 1;  // every run gets a new, but defined seed value

   // see: randomize the data
   ...

   // see: generate the folds
   ...
 }
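The loop above combines the two steps: each run reshuffles with its own seed (`i + 1`) and then cuts the shuffled data into folds. The following is a plain-Java sketch of that structure with no Weka involved. `MultiRunSketch` and `testCounts` are hypothetical names, and `Collections.shuffle` stands in for `Instances.randomize`.

Instead of training a classifier, the sketch only counts how often each instance lands in a test fold. This makes it easy to see that r runs of k-fold cross-validation test every instance exactly r times.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical sketch of "runs x folds" cross-validation bookkeeping (plain Java, no Weka).
class MultiRunSketch {

    static int[] testCounts(int numInstances, int runs, int folds) {
        int[] timesTested = new int[numInstances];
        List<Integer> data = new ArrayList<>();
        for (int i = 0; i < numInstances; i++) {
            data.add(i);
        }
        for (int i = 0; i < runs; i++) {
            int seed = i + 1;                                  // every run gets a new, but defined seed
            List<Integer> randData = new ArrayList<>(data);
            Collections.shuffle(randData, new Random(seed));   // randomize the data
            for (int n = 0; n < folds; n++) {                  // generate the folds
                int first = n * (numInstances / folds) + Math.min(n, numInstances % folds);
                int size = numInstances / folds + (n < numInstances % folds ? 1 : 0);
                for (int idx : randData.subList(first, first + size)) {
                    timesTested[idx]++;                        // instance idx is in this test fold
                }
            }
        }
        return timesTested;
    }

    public static void main(String[] args) {
        // 25 instances, 10 runs of 10-fold CV: every entry of the array is 10
        System.out.println(Arrays.toString(testCounts(25, 10, 10)));
    }
}
```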
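A simple experiment:
We continue classifying the red and white wines from the previous section. The classifier is unchanged; only the repeated cross-validation procedure has been added.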
 package assignment2;

 import weka.core.Instances;
 import weka.core.Utils;
 import weka.classifiers.AbstractClassifier;
 import weka.classifiers.Classifier;
 import weka.classifiers.Evaluation;
 import weka.classifiers.trees.J48;
 import weka.filters.Filter;
 import weka.filters.unsupervised.attribute.Remove;

 import java.io.FileReader;
 import java.util.Random;

 public class cv_rw {

   // Loads an ARFF file and removes its last attribute
   public static Instances getFileInstances(String filename) throws Exception {
     FileReader frData = new FileReader(filename);
     Instances data = new Instances(frData);
     frData.close();
     int length = data.numAttributes();
     String[] options = new String[2];
     options[0] = "-R";                      // range of attributes to remove (1-based)
     options[1] = Integer.toString(length);  // i.e. the last attribute
     Remove remove = new Remove();
     remove.setOptions(options);
     remove.setInputFormat(data);
     Instances newData = Filter.useFilter(data, remove);
     return newData;
   }

   public static void main(String[] args) throws Exception {
     // load data and set the class index (the last remaining attribute)
     Instances data = getFileInstances("D://Weka_tutorial//WineQuality//RedWhiteWine.arff");
     // System.out.println(data);
     data.setClassIndex(data.numAttributes() - 1);

     // classifier (alternative: choose it from the command line)
     // String[] tmpOptions;
     // String classname;
     // tmpOptions = Utils.splitOptions(Utils.getOption("W", args));
     // classname = tmpOptions[0];
     // tmpOptions[0] = "";
     // Classifier cls = (Classifier) Utils.forName(Classifier.class, classname, tmpOptions);
     //
     // // other options
     // int runs = Integer.parseInt(Utils.getOption("r", args));   // number of repeated runs
     // int folds = Integer.parseInt(Utils.getOption("x", args));
     int runs = 1;
     int folds = 10;
     J48 j48 = new J48();

     // perform cross-validation
     for (int i = 0; i < runs; i++) {
       // randomize data
       int seed = i + 1;
       Random rand = new Random(seed);
       Instances randData = new Instances(data);
       randData.randomize(rand);
       // Optional stratification (nominal class only): gives every fold roughly the
       // class distribution of the full dataset. It is switched off in this run.
       // if (randData.classAttribute().isNominal())
       //   randData.stratify(folds);

       Evaluation eval = new Evaluation(randData);
       for (int n = 0; n < folds; n++) {
         Instances train = randData.trainCV(folds, n);
         Instances test = randData.testCV(folds, n);
         // the above code is used by the StratifiedRemoveFolds filter, the
         // code below by the Explorer/Experimenter:
         // Instances train = randData.trainCV(folds, n, rand);

         // build and evaluate a fresh copy of the classifier on this fold
         // (Weka 3.6 used Classifier.makeCopy; Weka 3.7+ uses AbstractClassifier.makeCopy)
         Classifier j48Copy = AbstractClassifier.makeCopy(j48);
         j48Copy.buildClassifier(train);
         eval.evaluateModel(j48Copy, test);
       }

       // output evaluation
       System.out.println();
       System.out.println("=== Setup run " + (i + 1) + " ===");
       System.out.println("Classifier: " + j48.getClass().getName());
       System.out.println("Dataset: " + data.relationName());
       System.out.println("Folds: " + folds);
       System.out.println("Seed: " + seed);
       System.out.println();
       System.out.println(eval.toSummaryString("=== " + folds + "-fold Cross-validation run " + (i + 1) + "===", false));
     }
   }
 }
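In the program, one `Evaluation` object collects the predictions of all ten folds. Its summary is therefore pooled over every test instance, which is why the total reported is 6497. It is not an average of ten per-fold percentages.

The following is a plain-Java sketch contrasting the two ways of summarizing. `PooledAccuracySketch` is a hypothetical name, and the per-fold counts in `main` are made up purely for illustration.

```java
// Hypothetical sketch: pooled accuracy (what a single, reused Evaluation reports)
// versus the unweighted mean of per-fold accuracies. Fold counts below are made up.
class PooledAccuracySketch {

    // Total correct divided by total tested, across all folds
    static double pooledPercent(int[] correct, int[] tested) {
        int c = 0, t = 0;
        for (int i = 0; i < correct.length; i++) {
            c += correct[i];
            t += tested[i];
        }
        return 100.0 * c / t;
    }

    // Unweighted mean of the per-fold percentages
    static double meanOfFoldPercents(int[] correct, int[] tested) {
        double sum = 0.0;
        for (int i = 0; i < correct.length; i++) {
            sum += 100.0 * correct[i] / tested[i];
        }
        return sum / correct.length;
    }

    public static void main(String[] args) {
        int[] tested = {4, 3, 3};  // hypothetical: 10 instances split into 3 folds
        int[] correct = {4, 2, 3}; // hypothetical per-fold correct counts
        System.out.printf("pooled: %.4f %%%n", pooledPercent(correct, tested));              // 90.0000 %
        System.out.printf("mean of folds: %.4f %%%n", meanOfFoldPercents(correct, tested)); // 88.8889 %
    }
}
```

With 6497 instances and 10 folds, the fold sizes differ by at most one. The two numbers are therefore usually close, but they are not the same quantity.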

Running the program gives the following results:

=== Setup run 1 ===
Classifier: weka.classifiers.trees.J48
Dataset: RedWhiteWine-weka.filters.unsupervised.instance.Randomize-S42-weka.filters.unsupervised.instance.Randomize-S42-weka.filters.unsupervised.attribute.Remove-R13
Folds: 10
Seed: 1

=== 10-fold Cross-validation run 1===
Correctly Classified Instances        6415               98.7379 %
Incorrectly Classified Instances        82                1.2621 %
Kappa statistic                          0.9658
Mean absolute error                      0.0159
Root mean squared error                  0.1109
Relative absolute error                  4.2898 %
Root relative squared error             25.7448 %
Total Number of Instances             6497
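As a quick consistency check, the two percentages above can be recomputed from the counts: 6415 correct and 82 incorrect out of 6497. The check assumes every instance has weight 1, in which case each percentage is simply count / total × 100. `PercentCheck` is a hypothetical name.

```java
// Hypothetical check: recompute the reported percentages from the raw counts,
// assuming unit instance weights.
class PercentCheck {

    static double percent(int count, int total) {
        return 100.0 * count / total;
    }

    public static void main(String[] args) {
        int correct = 6415, incorrect = 82, total = 6497;
        System.out.printf("%.4f %%%n", percent(correct, total));   // 98.7379 %
        System.out.printf("%.4f %%%n", percent(incorrect, total)); // 1.2621 %
        System.out.println(correct + incorrect == total);          // true
    }
}
```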
Reposted from: https://www.cnblogs.com/7899-89/p/3667330.html