/*
 * Decompiled with CFR 0.152.
 */
package libsvm;

import java.util.Map;
import libsvm.LibSVM;
import libsvm.svm_parameter;
import net.sf.javaml.classification.evaluation.CrossValidation;
import net.sf.javaml.classification.evaluation.PerformanceMeasure;
import net.sf.javaml.core.Dataset;

/**
 * Exhaustive grid search over the SVM cost parameter {@code C} and, for kernels other than
 * kernel type 0 (libsvm's linear kernel), the kernel parameter {@code gamma}.
 *
 * <p>Each candidate combination is evaluated with {@link CrossValidation}. The combination with
 * the highest mean per-class cross-validated accuracy is written back into the supplied
 * {@link svm_parameter}.
 *
 * <p>Not thread-safe: an instance keeps mutable search state and mutates the parameter object
 * passed to {@link #search}.
 */
public class GridSearch {
    /** {@code kernel_type} value that the search treats as gamma-free (libsvm's LINEAR kernel). */
    private static final int LINEAR_KERNEL = 0;

    private final LibSVM classifier;
    private final Dataset dataset;
    private final int folds;
    private final CrossValidation cv;
    /** Best mean accuracy seen in the current search; NEGATIVE_INFINITY means "none yet". */
    private double bestAccuracy;
    private double bestC;
    private double bestGamma;
    /** Candidate arrays of the current search; {@code gamma} is null when gamma is not searched. */
    private double[] C;
    private double[] gamma;
    /** Parameter object being tuned; mutated in place for every evaluated candidate. */
    private svm_parameter svmParameters;

    /**
     * Creates a grid search that tunes {@code classifier} on {@code dataset}.
     *
     * @param classifier the LibSVM classifier whose parameters are tuned
     * @param dataset the data used for cross-validation
     * @param folds the number of cross-validation folds
     */
    public GridSearch(LibSVM classifier, Dataset dataset, int folds) {
        this.classifier = classifier;
        this.dataset = dataset;
        this.folds = folds;
        this.cv = new CrossValidation(this.classifier);
        resetBest();
    }

    /**
     * Evaluates every candidate {@code C} (and every candidate {@code gamma} for non-linear
     * kernels). It writes the best-scoring values into {@code param} and configures the
     * classifier with it.
     *
     * @param param the parameter template to tune; it is mutated in place and also returned
     * @param C candidate cost values; must be non-null and non-empty
     * @param gamma candidate gamma values. It is ignored when {@code param.kernel_type} is 0
     *     (linear) or when it is null or empty; in those cases {@code param.gamma} is left as-is.
     * @return {@code param}, with {@code C} (and {@code gamma}, if searched) set to the best values
     * @throws IllegalArgumentException if {@code C} is null or empty
     * @throws IllegalStateException if no candidate produced a usable (non-NaN) accuracy
     */
    public svm_parameter search(svm_parameter param, double[] C, double[] gamma) {
        if (C == null || C.length == 0) {
            throw new IllegalArgumentException("at least one C candidate is required");
        }
        // The linear kernel does not use gamma, so searching over it would only repeat CV runs.
        boolean searchGamma =
                param.kernel_type != LINEAR_KERNEL && gamma != null && gamma.length > 0;
        this.C = C;
        this.gamma = searchGamma ? gamma : null;
        this.svmParameters = param;
        // Start every search from scratch so results of a previous call cannot leak in.
        resetBest();

        for (int i = 0; i < C.length; ++i) {
            if (searchGamma) {
                for (int j = 0; j < gamma.length; ++j) {
                    crossValidation(i, j);
                }
            } else {
                crossValidation(i, null);
            }
        }

        if (this.bestAccuracy == Double.NEGATIVE_INFINITY) {
            throw new IllegalStateException(
                    "cross-validation produced no usable accuracy for any candidate");
        }
        param.C = this.bestC;
        if (searchGamma) {
            param.gamma = this.bestGamma;
        }
        // Leave the classifier configured with the selected values, not the last ones tried.
        this.classifier.setParameters(param);
        return param;
    }

    /** Clears the best-so-far state; NEGATIVE_INFINITY guarantees any real accuracy wins. */
    private void resetBest() {
        this.bestAccuracy = Double.NEGATIVE_INFINITY;
        this.bestC = Double.NaN;
        this.bestGamma = Double.NaN;
    }

    /**
     * Cross-validates one candidate and records it if it beats the best mean accuracy so far.
     *
     * @param cIndex index into {@code C}
     * @param gammaIndex index into {@code gamma}, or null when gamma is not being searched
     */
    private void crossValidation(int cIndex, Integer gammaIndex) {
        double candidateC = this.C[cIndex];
        this.svmParameters.C = candidateC;
        if (gammaIndex != null) {
            this.svmParameters.gamma = this.gamma[gammaIndex];
        }
        this.classifier.setParameters(this.svmParameters);

        Map<Object, PerformanceMeasure> perfMap = this.cv.crossValidation(this.dataset, this.folds);
        if (perfMap.isEmpty()) {
            // Nothing to average; skipping avoids a 0/0 NaN score.
            return;
        }
        double accuracySum = 0.0;
        for (PerformanceMeasure pm : perfMap.values()) {
            accuracySum += pm.getAccuracy();
        }
        // Unweighted mean over the map's entries (one PerformanceMeasure per key).
        double averageAccuracy = accuracySum / perfMap.size();

        // A NaN average compares false here and is therefore never selected.
        if (averageAccuracy > this.bestAccuracy) {
            this.bestAccuracy = averageAccuracy;
            this.bestC = candidateC;
            if (gammaIndex != null) {
                this.bestGamma = this.gamma[gammaIndex];
            }
        }
    }
}

