/*
 * Decompiled with CFR 0.152.
 */
package net.sf.javaml.featureselection.scoring;

import be.abeel.io.Copier;
import java.util.Random;
import net.sf.javaml.classification.evaluation.PerformanceMeasure;
import net.sf.javaml.classification.tree.RandomTree;
import net.sf.javaml.core.Dataset;
import net.sf.javaml.core.DefaultDataset;
import net.sf.javaml.core.Instance;
import net.sf.javaml.featureselection.FeatureScoring;
import net.sf.javaml.tools.DatasetTools;
import net.sf.javaml.utils.ArrayUtils;
import net.sf.javaml.utils.MathUtils;

/**
 * Scores attributes by how much randomly perturbing them degrades a random
 * forest's out-of-bag F-measure with respect to a chosen positive class.
 *
 * <p>For each of {@code numTrees} trees, a bootstrap sample of the data is drawn
 * and a {@link RandomTree} is trained on it. The instances that were not drawn
 * (the out-of-bag set) are classified unchanged to accumulate a baseline
 * confusion matrix. They are then classified again once per attribute and
 * perturbation round, with that attribute's value replaced by a uniform random
 * number in [0, 1). An attribute's raw importance is the baseline F-measure minus
 * the mean F-measure over its perturbation rounds. The raw scores are finally
 * shifted and scaled so the lowest is 0 and the highest is 1.
 *
 * <p>NOTE(review): perturbation values are drawn from [0, 1), which presumably
 * assumes attribute values are normalized to that range — confirm with callers.
 */
public class RandomForestAttributeEvaluation
implements FeatureScoring {
    /** Index of the true-positive count in a confusion-count array. */
    private static final int TP = 0;
    /** Index of the false-positive count in a confusion-count array. */
    private static final int FP = 1;
    /** Index of the false-negative count in a confusion-count array. */
    private static final int FN = 2;
    /** Index of the true-negative count in a confusion-count array. */
    private static final int TN = 3;

    private final int numTrees;
    private final Object positiveClass;
    /** Passed to every {@link RandomTree}; presumably the number of candidate attributes per split — TODO confirm. */
    private int k;
    /** Single source of randomness for bootstrapping, tree building and perturbation. */
    private final Random rg;
    private int numPerturbations;
    /** Normalized importance per attribute; {@code null} until {@link #build(Dataset)} has run. */
    private double[] importance;

    /**
     * Sets the {@code k} parameter handed to each {@link RandomTree} (default 5).
     *
     * @param k value forwarded to the {@code RandomTree} constructor
     */
    public void setK(int k) {
        this.k = k;
    }

    /**
     * Sets how many independent random perturbations are applied per attribute
     * and per tree (default 1).
     *
     * @param p number of perturbation rounds, at least 1
     * @throws IllegalArgumentException if {@code p < 1}
     */
    public void setPerturbations(int p) {
        if (p < 1) {
            throw new IllegalArgumentException("Number of perturbations must be at least 1, was " + p);
        }
        this.numPerturbations = p;
    }

    /**
     * Creates an evaluator with {@code k = 5} and one perturbation round.
     *
     * @param numTrees      number of trees in the forest
     * @param positiveClass class value counted as "positive" when computing the F-measure
     * @param rg            random generator used for bootstraps, trees and perturbations
     */
    public RandomForestAttributeEvaluation(int numTrees, Object positiveClass, Random rg) {
        this.rg = rg;
        this.numTrees = numTrees;
        this.positiveClass = positiveClass;
        this.k = 5;
        this.numPerturbations = 1;
    }

    /**
     * Trains the forest on {@code data} and computes one importance score per attribute.
     *
     * @param data the training data; its class values are compared with {@code positiveClass}
     */
    @Override
    public void build(Dataset data) {
        Copier instCopier = new Copier();
        int numAttributes = data.noAttributes();
        // Confusion counts indexed by TP/FP/FN/TN, accumulated over all trees.
        int[] baseline = new int[4];
        int[][][] perturbedCounts = new int[numAttributes][this.numPerturbations][4];

        for (int t = 0; t < this.numTrees; ++t) {
            RandomTree tree = new RandomTree(this.k, this.rg);
            Dataset sample = DatasetTools.bootstrap(data, data.size(), this.rg);
            tree.buildClassifier(sample);

            // Out-of-bag set: every instance of data that the bootstrap did not draw.
            DefaultDataset outOfBag = new DefaultDataset();
            outOfBag.addAll(data);
            outOfBag.removeAll(sample);

            for (Instance inst : outOfBag) {
                tally(tree, inst, baseline);
            }

            for (int a = 0; a < numAttributes; ++a) {
                for (int p = 0; p < this.numPerturbations; ++p) {
                    for (Instance inst : outOfBag) {
                        // Copy first so the perturbation never touches the caller's data.
                        Instance perturbed = (Instance) instCopier.copy((Object) inst);
                        // Use the injected generator (not Math.random) so seeded runs are reproducible.
                        perturbed.put(a, this.rg.nextDouble());
                        tally(tree, perturbed, perturbedCounts[a][p]);
                    }
                }
            }
        }

        double originalF = fMeasure(baseline);
        double[] scores = new double[numAttributes];
        for (int a = 0; a < numAttributes; ++a) {
            double[] perturbedF = new double[this.numPerturbations];
            for (int p = 0; p < this.numPerturbations; ++p) {
                perturbedF[p] = fMeasure(perturbedCounts[a][p]);
            }
            // A larger drop in F-measure means the attribute matters more.
            scores[a] = originalF - MathUtils.arithmicMean(perturbedF);
        }
        rescale(scores);
        this.importance = scores;
    }

    /**
     * Classifies one instance with {@code tree} and increments the matching cell
     * of {@code counts} (indexed by TP, FP, FN, TN) relative to {@code positiveClass}.
     */
    private void tally(RandomTree tree, Instance inst, int[] counts) {
        Object predClass = tree.classify(inst);
        boolean actuallyPositive = inst.classValue().equals(this.positiveClass);
        if (predClass.equals(this.positiveClass)) {
            counts[actuallyPositive ? TP : FP]++;
        } else {
            counts[actuallyPositive ? FN : TN]++;
        }
    }

    /** Returns the F-measure for a confusion-count array indexed by TP, FP, FN, TN. */
    private static double fMeasure(int[] counts) {
        // PerformanceMeasure takes its counts in (tp, tn, fp, fn) order.
        return new PerformanceMeasure(counts[TP], counts[TN], counts[FP], counts[FN]).getFMeasure();
    }

    /**
     * Shifts {@code values} so the minimum becomes 0, then scales so the maximum
     * becomes 1. An empty array is left untouched. An array whose values are all
     * equal is left at 0 rather than being divided by zero. Relies on
     * {@code ArrayUtils.add}/{@code normalize} modifying the array in place.
     */
    private static void rescale(double[] values) {
        if (values.length == 0) {
            return;
        }
        ArrayUtils.add(values, -ArrayUtils.min(values));
        double max = ArrayUtils.max(values);
        if (max > 0) {
            ArrayUtils.normalize(values, max);
        }
    }

    /** Returns the computed importances, failing clearly if {@link #build(Dataset)} has not run. */
    private double[] requireBuilt() {
        if (this.importance == null) {
            throw new IllegalStateException("build(Dataset) must be called before querying scores");
        }
        return this.importance;
    }

    /**
     * Returns the normalized importance of an attribute, in [0, 1].
     *
     * @param attribute attribute index
     * @return importance score; higher means perturbing the attribute hurts accuracy more
     * @throws IllegalStateException if {@link #build(Dataset)} has not been called
     */
    @Override
    public double score(int attribute) {
        return requireBuilt()[attribute];
    }

    /**
     * Returns the number of attributes scored by the last {@link #build(Dataset)}.
     *
     * @throws IllegalStateException if {@link #build(Dataset)} has not been called
     */
    @Override
    public int noAttributes() {
        return requireBuilt().length;
    }
}

