/*
 * Decompiled with CFR 0.152.
 */
package net.sf.javaml.classification.evaluation;

import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import net.sf.javaml.classification.Classifier;
import net.sf.javaml.classification.evaluation.PerformanceMeasure;
import net.sf.javaml.core.Dataset;
import net.sf.javaml.core.DefaultDataset;
import net.sf.javaml.core.Instance;

/**
 * Performs k-fold cross-validation of a {@link Classifier} and reports, for every class of the
 * data set, one-versus-rest confusion counts ({@code tp}, {@code fp}, {@code tn}, {@code fn})
 * accumulated in a {@link PerformanceMeasure}.
 *
 * <p>The wrapped classifier is rebuilt for every fold, so after a {@code crossValidation} call it
 * is left in the state produced by its last {@code buildClassifier} invocation.
 */
public class CrossValidation {
    /** Number of folds used by {@link #crossValidation(Dataset)}. */
    private static final int DEFAULT_NUM_FOLDS = 10;

    private final Classifier classifier;

    /**
     * Creates a cross-validator for the given classifier.
     *
     * @param classifier the classifier to evaluate; it is retrained on every fold
     * @throws NullPointerException if {@code classifier} is null
     */
    public CrossValidation(Classifier classifier) {
        if (classifier == null) {
            throw new NullPointerException("classifier");
        }
        this.classifier = classifier;
    }

    /**
     * Runs cross-validation of the wrapped classifier on {@code data}.
     *
     * <p>{@code data} is split with {@code data.folds(numFolds, rg)}. Each fold in turn is held out
     * as the test set while the classifier is built on the union of the remaining folds. Every
     * held-out instance is classified and counted, for each class, as a true/false
     * positive/negative in a one-versus-rest sense. Class values and predictions are compared
     * null-safely, so a classifier that returns {@code null} does not cause an exception.
     *
     * @param data     the data set to evaluate on
     * @param numFolds number of folds requested from {@code data.folds}
     * @param rg       random source passed to {@code data.folds}
     * @return a map from each class in {@code data.classes()} to its accumulated counts
     */
    public Map<Object, PerformanceMeasure> crossValidation(Dataset data, int numFolds, Random rg) {
        Dataset[] folds = data.folds(numFolds, rg);
        Map<Object, PerformanceMeasure> out = new HashMap<Object, PerformanceMeasure>();
        for (Object clazz : data.classes()) {
            out.put(clazz, new PerformanceMeasure());
        }
        // Iterate over the folds actually returned rather than trusting numFolds, so we can
        // never index past the end of the array.
        for (int i = 0; i < folds.length; ++i) {
            this.classifier.buildClassifier(trainingSetExcluding(folds, i));
            for (Instance instance : folds[i]) {
                Object prediction = this.classifier.classify(instance);
                tally(out, instance.classValue(), prediction);
            }
        }
        return out;
    }

    /**
     * Runs cross-validation with a freshly seeded random source.
     *
     * @param data  the data set to evaluate on
     * @param folds number of folds
     * @return per-class confusion counts, as for {@link #crossValidation(Dataset, int, Random)}
     */
    public Map<Object, PerformanceMeasure> crossValidation(Dataset data, int folds) {
        // new Random() picks a seed that differs per invocation; seeding with the wall clock
        // would give identical fold splits to calls made within the same millisecond.
        return this.crossValidation(data, folds, new Random());
    }

    /**
     * Runs {@value #DEFAULT_NUM_FOLDS}-fold cross-validation with a freshly seeded random source.
     *
     * @param data the data set to evaluate on
     * @return per-class confusion counts, as for {@link #crossValidation(Dataset, int, Random)}
     */
    public Map<Object, PerformanceMeasure> crossValidation(Dataset data) {
        return this.crossValidation(data, DEFAULT_NUM_FOLDS);
    }

    /** Concatenates every fold except the one at index {@code heldOut} into a new training set. */
    private static DefaultDataset trainingSetExcluding(Dataset[] folds, int heldOut) {
        DefaultDataset training = new DefaultDataset();
        for (int j = 0; j < folds.length; ++j) {
            if (j != heldOut) {
                training.addAll(folds[j]);
            }
        }
        return training;
    }

    /**
     * Adds the outcome of classifying one instance to the one-versus-rest counts of every class.
     *
     * @param measures  per-class counters to update
     * @param actual    the instance's true class value
     * @param predicted the class value returned by the classifier
     */
    private static void tally(Map<Object, PerformanceMeasure> measures, Object actual, Object predicted) {
        boolean correct = sameClass(actual, predicted);
        for (Map.Entry<Object, PerformanceMeasure> entry : measures.entrySet()) {
            Object clazz = entry.getKey();
            PerformanceMeasure pm = entry.getValue();
            if (correct) {
                // Correct prediction: a hit for the true class, a correct rejection for the rest.
                if (sameClass(clazz, actual)) {
                    pm.tp += 1.0;
                } else {
                    pm.tn += 1.0;
                }
            } else if (sameClass(clazz, predicted)) {
                pm.fp += 1.0;
            } else if (sameClass(clazz, actual)) {
                pm.fn += 1.0;
            } else {
                pm.tn += 1.0;
            }
        }
    }

    /** Null-safe equality check for class values. */
    private static boolean sameClass(Object a, Object b) {
        return a == null ? b == null : a.equals(b);
    }
}

