/*
 * Decompiled with CFR 0.152.
 */
package net.sf.javaml.classification.tree;

import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import java.util.SortedSet;
import java.util.Vector;
import net.sf.javaml.classification.Classifier;
import net.sf.javaml.core.Dataset;
import net.sf.javaml.core.DefaultDataset;
import net.sf.javaml.core.Instance;

/**
 * A randomized binary decision tree classifier.
 *
 * <p>At every internal node a random subset of attribute indices is sampled. The node then
 * computes two centres over those attributes: the mean of the training instances whose class
 * index (as reported by {@code Dataset.classIndex}) is 0, and the mean of all remaining
 * instances. An instance is routed to the right child when its L1 (Manhattan) distance to the
 * right centre is strictly smaller than to the left centre; otherwise (including ties) it goes
 * left.
 *
 * <p>A node whose data contains a single class becomes a leaf predicting that class. If a
 * sampled subset does not send instances to both sides, sampling is retried with a larger
 * subset. Once {@code iterations² × attributes} exceeds the number of attributes in the data,
 * the node becomes a leaf that predicts a class chosen at random from its data.
 */
public class RandomTree implements Classifier {
    private static final long serialVersionUID = -6421557885832628441L;
    /** Requested number of attributes sampled per node (the subset may grow on retries). */
    private int noSplitAttributes = -1;
    /** Source of randomness, shared with all child nodes. */
    private Random rg = null;
    /** Mean of the non-class-0 instances over {@link #splitAttributes}. */
    private float[] rightCenter = null;
    /** Mean of the class-0 instances over {@link #splitAttributes}. */
    private float[] leftCenter = null;
    /** Prediction of a leaf node; {@code null} for internal nodes. */
    private Object finalClass = null;
    private RandomTree leftChild = null;
    private RandomTree rightChild = null;
    /** Attribute indices this node projects instances onto. */
    private Vector<Integer> splitAttributes = null;
    /** Class set of the root's training data, used by {@link #classDistribution}. */
    private SortedSet<Object> parentClasses = null;

    /**
     * Creates a node that shares the random generator and the root's class set.
     *
     * @throws IllegalArgumentException if {@code attributes < 1}; with 0 training would never
     *     terminate and with negative values sampling fails inside {@link Random#nextInt(int)}
     */
    private RandomTree(int attributes, Random rg, SortedSet<Object> classes) {
        if (attributes < 1) {
            throw new IllegalArgumentException(
                    "number of split attributes must be at least 1, got " + attributes);
        }
        this.rg = rg;
        this.noSplitAttributes = attributes;
        this.parentClasses = classes;
    }

    /**
     * Creates an untrained tree.
     *
     * @param attributes number of attributes to sample at each node; must be at least 1
     * @param rg random generator used for attribute sampling and fallback class selection
     * @throws IllegalArgumentException if {@code attributes < 1}
     */
    public RandomTree(int attributes, Random rg) {
        this(attributes, rg, null);
    }

    /**
     * Trains this (sub)tree on {@code data}. The supplied dataset is not modified.
     *
     * @param data training data
     * @throws IllegalArgumentException if {@code data} reports no classes (e.g. it is empty)
     */
    @Override
    public void buildClassifier(Dataset data) {
        if (this.parentClasses == null) {
            this.parentClasses = data.classes();
        }
        if (data.classes().isEmpty()) {
            // Previously this looped and then failed inside Random.nextInt(0) with the same
            // exception type; fail fast with a clear message instead.
            throw new IllegalArgumentException("cannot train a RandomTree on data with no classes");
        }
        if (data.classes().size() == 1) {
            // Pure node: no further split needed.
            this.finalClass = data.classes().first();
            return;
        }
        DefaultDataset left;
        DefaultDataset right;
        int iterationCount = 0;
        while (true) {
            ++iterationCount;
            // Each retry allows a quadratically larger attribute subset.
            int widening = iterationCount * iterationCount;
            this.sampleSplitAttributes(data.noAttributes(), widening);
            int width = this.splitAttributes.size();
            this.leftCenter = new float[width];
            this.rightCenter = new float[width];
            int count0 = 0;
            int count1 = 0;
            for (Instance inst : data) {
                float[] center;
                if (data.classIndex(inst.classValue()) == 0) {
                    ++count0;
                    center = this.leftCenter;
                } else {
                    ++count1;
                    center = this.rightCenter;
                }
                for (int j = 0; j < width; ++j) {
                    center[j] += inst.value(this.splitAttributes.get(j));
                }
            }
            for (int j = 0; j < width; ++j) {
                this.leftCenter[j] /= count0;
                this.rightCenter[j] /= count1;
            }
            double[] projected = new double[width];
            left = new DefaultDataset();
            right = new DefaultDataset();
            for (Instance inst : data) {
                this.project(inst, projected);
                if (this.goesRight(projected)) {
                    right.add(inst);
                } else {
                    left.add(inst);
                }
            }
            if (left.size() != 0 && right.size() != 0) {
                break;
            }
            if (widening * this.noSplitAttributes > data.noAttributes()) {
                // No useful split found even with the widest subset: guess a class.
                Vector<Object> possibleClasses = new Vector<Object>(data.classes());
                this.finalClass = possibleClasses.get(this.rg.nextInt(possibleClasses.size()));
                return;
            }
        }
        this.leftChild = new RandomTree(this.noSplitAttributes, this.rg, this.parentClasses);
        this.leftChild.buildClassifier(left);
        // The partitions are private to this node; release them once each subtree is built.
        left.clear();
        this.rightChild = new RandomTree(this.noSplitAttributes, this.rg, this.parentClasses);
        this.rightChild.buildClassifier(right);
        right.clear();
    }

    /**
     * Randomly picks the attribute indices this node splits on, storing them in
     * {@link #splitAttributes}. Indices are removed uniformly at random from
     * {@code 0..noAttributes-1} while {@code size / widening > noSplitAttributes}.
     */
    private void sampleSplitAttributes(int noAttributes, int widening) {
        this.splitAttributes = new Vector<Integer>();
        for (int i = 0; i < noAttributes; ++i) {
            this.splitAttributes.add(i);
        }
        while (this.splitAttributes.size() / widening > this.noSplitAttributes) {
            // remove(int): removes by position, not by value.
            this.splitAttributes.remove(this.rg.nextInt(this.splitAttributes.size()));
        }
    }

    /**
     * Writes the values of {@code inst} at {@link #splitAttributes} into {@code out}, whose
     * length must equal {@code splitAttributes.size()}. Shared by training and prediction so
     * both use the same projection.
     */
    private void project(Instance inst, double[] out) {
        for (int i = 0; i < out.length; ++i) {
            out[i] = inst.value(this.splitAttributes.get(i));
        }
    }

    /** Returns true when {@code x} is strictly closer (L1) to the right centre than the left. */
    private boolean goesRight(double[] x) {
        return this.dist(x, this.leftCenter) > this.dist(x, this.rightCenter);
    }

    /** L1 (Manhattan) distance between {@code a} and the first {@code a.length} entries of {@code b}. */
    private double dist(double[] a, float[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; ++i) {
            sum += Math.abs(a[i] - (double) b[i]);
        }
        return sum;
    }

    /**
     * Predicts the class of {@code instance} by descending the tree until a leaf is reached.
     *
     * @param instance instance to classify
     * @return the class stored in the leaf that {@code instance} reaches
     */
    @Override
    public Object classify(Instance instance) {
        if (this.finalClass != null) {
            return this.finalClass;
        }
        assert (this.rightCenter != null);
        assert (this.leftCenter != null);
        assert (this.leftChild != null);
        assert (this.rightChild != null);
        assert (this.splitAttributes != null);
        // Size by the attributes actually used in training, not the requested count: they
        // differ when the data has fewer attributes or the subset was widened on retry.
        double[] projected = new double[this.splitAttributes.size()];
        this.project(instance, projected);
        return this.goesRight(projected)
                ? this.rightChild.classify(instance)
                : this.leftChild.classify(instance);
    }

    /**
     * Returns a degenerate distribution: 0.0 for every class in the root's training data and
     * 1.0 for the class returned by {@link #classify(Instance)}.
     */
    @Override
    public Map<Object, Double> classDistribution(Instance instance) {
        HashMap<Object, Double> out = new HashMap<Object, Double>();
        for (Object e : this.parentClasses) {
            out.put(e, 0.0);
        }
        out.put(this.classify(instance), 1.0);
        return out;
    }
}

