A simple machine learning demo in Java

2017-01-08  freelands

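I have been busy with a competition recently and wanted to tackle it with machine learning. Since we work in Java, we used Weka, a machine learning library for Java. The Weka jar can be downloaded from the official website.

1. Preparing the data

Create a txt file in the project (the code below reads it as weather.txt) and paste the following data into it: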

@relation weather  
  
@attribute outlook {sunny, overcast, rainy}  
@attribute temperature numeric  
@attribute humidity numeric  
@attribute windy {TRUE, FALSE}  
@attribute play {yes, no}  
  
@data  
sunny,85,85,FALSE,no  
sunny,80,90,TRUE,no  
overcast,83,86,FALSE,yes  
rainy,70,96,FALSE,yes  
rainy,68,80,FALSE,yes  
rainy,65,70,TRUE,no  
overcast,64,65,TRUE,yes  
sunny,72,95,FALSE,no  
sunny,69,70,FALSE,yes  
rainy,75,80,FALSE,yes  
sunny,75,70,TRUE,yes  
overcast,72,90,TRUE,yes  
overcast,81,75,FALSE,yes  
rainy,71,91,TRUE,no  
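
To make the file layout concrete: this is Weka's ARFF format. `@relation` names the dataset, each `@attribute` line declares one column (either `numeric` or a set of nominal values in braces), and every line after `@data` is one comma-separated instance. The following is a minimal sketch in plain Java that only lists the declared attributes. The class `ArffHeaderSketch` and its method are hypothetical names for illustration, and it assumes the simple subset of ARFF used above (no quoting, no comments, no sparse rows). The actual parsing in this post is done by Weka's `Instances` class.

```java
import java.util.ArrayList;
import java.util.List;

public class ArffHeaderSketch {

    /** Returns "name -> type" for each @attribute line that appears before @data. */
    public static List<String> attributes(String arff) {
        List<String> result = new ArrayList<>();
        for (String raw : arff.split("\n")) {
            String line = raw.trim();
            String lower = line.toLowerCase();
            if (lower.startsWith("@data")) {
                break; // everything after @data is instance rows
            }
            if (lower.startsWith("@attribute")) {
                // "@attribute <name> <type>": split into at most three parts
                // (assumes a well-formed line with both name and type present)
                String[] parts = line.split("\\s+", 3);
                result.add(parts[1] + " -> " + parts[2]);
            }
        }
        return result;
    }

    public static void main(String[] args) {
        String arff = String.join("\n",
                "@relation weather",
                "",
                "@attribute outlook {sunny, overcast, rainy}",
                "@attribute temperature numeric",
                "@attribute play {yes, no}",
                "",
                "@data",
                "sunny,85,no");
        for (String a : attributes(arff)) {
            System.out.println(a);
        }
    }
}
```

The last attribute, `play`, is the one the classifiers are asked to predict, which is why the code later sets the class index to the last column.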

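The data above are sample data: part of them is used as the training set and part as the test set.

Training set

As I understand it, once you have chosen a suitable model, you train it on part of the data, and the model then has an initial decision logic. Of course, what it learns depends on your data.

Test set

As I see it, once the model has its decision logic, you run it on some other data to check how accurate the model is.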
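
The training/test idea can be illustrated without any library. Below is a minimal sketch, in plain Java, of how k-fold cross-validation partitions row indices: each fold holds out one contiguous block of rows as the test set and trains on the rest. `FoldSplitSketch` and `testRange` are hypothetical names, and the scheme shown (earlier folds absorb the leftover rows) is an assumption chosen for illustration; it is not claimed to match Weka's `trainCV`/`testCV` exactly.

```java
public class FoldSplitSketch {

    /**
     * Returns {start, endExclusive} of the rows held out for testing in the
     * given fold. Rows outside that range form the training set.
     */
    public static int[] testRange(int numRows, int numFolds, int fold) {
        int base = numRows / numFolds;   // minimum rows per fold
        int extra = numRows % numFolds;  // the first `extra` folds get one more row
        int size = base + (fold < extra ? 1 : 0);
        int start = fold * base + Math.min(fold, extra);
        return new int[] {start, start + size};
    }

    public static void main(String[] args) {
        int numRows = 14;  // the weather data has 14 instances
        int numFolds = 10; // the number of folds used in this post
        for (int fold = 0; fold < numFolds; fold++) {
            int[] r = testRange(numRows, numFolds, fold);
            int testSize = r[1] - r[0];
            System.out.println("fold " + fold + ": test rows [" + r[0] + ", " + r[1]
                    + "), train on the other " + (numRows - testSize));
        }
    }
}
```

In this sketch, 14 rows split into 10 folds means four folds test on two rows and six folds test on a single row, so one fold on its own says very little; only the result pooled over all folds is informative.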

2. Choosing a suitable model


import java.io.BufferedReader;  
import java.io.FileNotFoundException;  
import java.io.FileReader;  
import weka.classifiers.Classifier;  
import weka.classifiers.Evaluation;  
import weka.classifiers.evaluation.NominalPrediction;  
import weka.classifiers.rules.DecisionTable;  
import weka.classifiers.rules.PART;  
import weka.classifiers.trees.DecisionStump;  
import weka.classifiers.trees.J48;  
import weka.core.FastVector;  
import weka.core.Instances;  
   
public class WekaTest {  
    // Opens the data file. A missing file surfaces as FileNotFoundException
    // instead of a null reader that would fail later in new Instances(...).
    public static BufferedReader readDataFile(String filename) throws FileNotFoundException {
        return new BufferedReader(new FileReader(filename));
    }
   
    public static Evaluation classify(Classifier model,  
            Instances trainingSet, Instances testingSet) throws Exception {  
        Evaluation evaluation = new Evaluation(trainingSet);  
   
        model.buildClassifier(trainingSet);  
        evaluation.evaluateModel(model, testingSet);  
   
        return evaluation;  
    }  
   
    public static double calculateAccuracy(FastVector predictions) {  
        double correct = 0;  
   
        for (int i = 0; i < predictions.size(); i++) {  
            NominalPrediction np = (NominalPrediction) predictions.elementAt(i);  
            if (np.predicted() == np.actual()) {  
                correct++;  
            }  
        }  
   
        return 100 * correct / predictions.size();  
    }  
   
    public static Instances[][] crossValidationSplit(Instances data, int numberOfFolds) {  
        Instances[][] split = new Instances[2][numberOfFolds];  
   
        for (int i = 0; i < numberOfFolds; i++) {  
            split[0][i] = data.trainCV(numberOfFolds, i);  
            split[1][i] = data.testCV(numberOfFolds, i);  
        }  
   
        return split;  
    }  
   
    public static void main(String[] args) throws Exception {  
        BufferedReader datafile = readDataFile("weather.txt");  
   
        Instances data = new Instances(datafile);
        datafile.close(); // the whole dataset is in memory now
        data.setClassIndex(data.numAttributes() - 1);  
   
        // Do 10-fold cross-validation
        Instances[][] split = crossValidationSplit(data, 10);  
   
        // Separate split into training and testing arrays  
        Instances[] trainingSplits = split[0];  
        Instances[] testingSplits = split[1];  
   
        // Use a set of classifiers  
        Classifier[] models = {   
                new J48(), // a decision tree  
                new PART(),   
                new DecisionTable(),//decision table majority classifier  
                new DecisionStump() //one-level decision tree  
        };  
   
        // Run for each model  
        for (int j = 0; j < models.length; j++) {  
   
            // Collect the predictions from every fold for the current model in a FastVector
            // (FastVector is deprecated in Weka 3.7 and later)
            FastVector predictions = new FastVector();  
   
            // For each training-testing split pair, train and test the classifier  
            for (int i = 0; i < trainingSplits.length; i++) {  
                Evaluation validation = classify(models[j], trainingSplits[i], testingSplits[i]);  
   
                predictions.appendElements(validation.predictions());  
   
                // Uncomment to see the evaluation summary for each training-testing pair.
                //System.out.println(validation.toSummaryString());
            }  
   
            // Calculate overall accuracy of current classifier on all splits  
            double accuracy = calculateAccuracy(predictions);  
   
            // Print current classifier's name and accuracy in a complicated,  
            // but nice-looking way.  
            System.out.println("Accuracy of " + models[j].getClass().getSimpleName() + ": "  
                    + String.format("%.2f%%", accuracy)  
                    + "\n---------------------------------");  
        }  
   
    }  
}  

3. Results

Accuracy of J48: 50.00%
---------------------------------
Accuracy of PART: 50.00%
---------------------------------
Accuracy of DecisionTable: 64.29%
---------------------------------
Accuracy of DecisionStump: 21.43%
---------------------------------
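
Because the cross-validation tests each of the 14 instances exactly once, every percentage above corresponds to a whole number of correct predictions out of 14. A small sketch that recovers those counts from the printed values (the class name `AccuracyCheck` and its methods are hypothetical, not part of Weka):

```java
import java.util.Locale;

public class AccuracyCheck {

    /** Rounds a percentage back to the nearest count of correct predictions. */
    public static long correctCount(double accuracyPercent, int total) {
        return Math.round(accuracyPercent / 100.0 * total);
    }

    /** Formats correct/total with the same "%.2f%%" pattern used in the program above. */
    public static String asPercent(long correct, int total) {
        return String.format(Locale.ROOT, "%.2f%%", 100.0 * correct / total);
    }

    public static void main(String[] args) {
        int total = 14; // number of instances in the weather data
        double[] reported = {50.00, 50.00, 64.29, 21.43};
        for (double p : reported) {
            long correct = correctCount(p, total);
            System.out.println(p + "% -> " + correct + "/" + total
                    + " = " + asPercent(correct, total));
        }
    }
}
```

So J48 and PART got 7 of 14 right, DecisionTable 9, and DecisionStump 3. With so few instances, a single prediction moves the accuracy by about 7 percentage points, so these figures are best read as a demonstration of the workflow rather than a ranking of the classifiers.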