/*
 * Decompiled with CFR 0.152.
 */
package net.sf.javaml.featureselection.scoring;

import be.abeel.util.HashMap2D;
import net.sf.javaml.core.Dataset;
import net.sf.javaml.core.Instance;
import net.sf.javaml.featureselection.FeatureScoring;
import net.sf.javaml.filter.normalize.NormalizeMidrange;

/**
 * Scores features by the Kullback-Leibler divergence between the per-class
 * distributions of each attribute.
 * <p>
 * During {@link #build(Dataset)} every attribute is rescaled into
 * {@code [0, bins)} and discretized into {@code bins} equal-width bins. For
 * every ordered pair of distinct classes (p, q) the divergence D(P || Q) of each
 * attribute's binned distribution is computed and divided by the largest such
 * divergence over all attributes for that pair. The score of an attribute is the
 * largest normalized divergence over all class pairs, so scores lie in
 * {@code [0, 1]}.
 * <p>
 * Empty bins are smoothed with a small constant probability so the logarithm
 * stays finite.
 */
public class KullbackLeiblerDivergence implements FeatureScoring {

    /** Probability substituted for empty bins so that log(p / q) stays finite. */
    private static final double SMOOTHING = 1.0E-7;

    /** Per-attribute score: the maximum normalized divergence over all class pairs. */
    private double[] maxDivergence;

    /** Normalized per-attribute divergences, keyed by the ordered class pair (p, q). */
    private HashMap2D<Object, Object, double[]> pairWiseDivergence = new HashMap2D<Object, Object, double[]>();

    /** Number of equal-width bins used to discretize each attribute. */
    private final int bins;

    /** Creates a scorer that discretizes each attribute into 100 bins. */
    public KullbackLeiblerDivergence() {
        this(100);
    }

    /**
     * Creates a scorer with a custom number of bins.
     *
     * @param i number of equal-width bins per attribute; must be at least 1
     * @throws IllegalArgumentException if {@code i < 1}
     */
    public KullbackLeiblerDivergence(int i) {
        if (i < 1) {
            throw new IllegalArgumentException("Number of bins must be at least 1, got " + i);
        }
        this.bins = i;
    }

    /**
     * Computes the divergence-based score of every attribute.
     * <p>
     * The supplied dataset is not modified: normalization is applied to a copy.
     *
     * @param data the labelled training data
     */
    @Override
    public void build(Dataset data) {
        // NormalizeMidrange.filter rewrites values in place, so work on a copy to
        // keep the caller's dataset intact.
        Dataset binned = data.copy();
        // Map every attribute onto [~0, ~bins). The middle must be bins / 2.0
        // (not integer division) or odd bin counts get shifted, uneven bins.
        NormalizeMidrange nm = new NormalizeMidrange(this.bins / 2.0, this.bins - 1.0E-6);
        nm.build(binned);
        nm.filter(binned);

        this.maxDivergence = new double[binned.noAttributes()];
        this.pairWiseDivergence = new HashMap2D<Object, Object, double[]>();
        for (Object p : binned.classes()) {
            for (Object q : binned.classes()) {
                if (p.equals(q)) {
                    continue;
                }
                double[] d = this.pairWise(p, q, binned);
                this.pairWiseDivergence.put(p, q, d);
                for (int i = 0; i < d.length; ++i) {
                    if (d[i] > this.maxDivergence[i]) {
                        this.maxDivergence[i] = d[i];
                    }
                }
            }
        }
    }

    /**
     * Computes D(P || Q) per attribute between the binned distributions of class
     * {@code p} and class {@code q}, normalized by the largest value over all
     * attributes.
     *
     * @param p    class whose distribution is P
     * @param q    class whose distribution is Q
     * @param data dataset whose attribute values are already scaled into [0, bins)
     * @return per-attribute divergences; values are divided by the maximum
     *         divergence when that maximum is positive, otherwise left unscaled
     */
    private double[] pairWise(Object p, Object q, Dataset data) {
        int noAttributes = data.noAttributes();
        double[] divergence = new double[noAttributes];
        double maxSum = 0.0;
        for (int i = 0; i < noAttributes; ++i) {
            double[] countP = new double[this.bins];
            double[] countQ = new double[this.bins];
            double pCount = 0.0;
            double qCount = 0.0;
            for (Instance inst : data) {
                // Compare from the (non-null) class side so unlabeled instances don't NPE.
                Object label = inst.classValue();
                if (q.equals(label)) {
                    countQ[this.binIndex(inst.value(i))] += 1.0;
                    qCount += 1.0;
                }
                if (p.equals(label)) {
                    countP[this.binIndex(inst.value(i))] += 1.0;
                    pCount += 1.0;
                }
            }
            double sum = 0.0;
            // Without instances of both classes there is no distribution to compare.
            if (pCount > 0.0 && qCount > 0.0) {
                for (int j = 0; j < this.bins; ++j) {
                    double pj = countP[j] / pCount;
                    double qj = countQ[j] / qCount;
                    if (pj == 0.0) {
                        pj = SMOOTHING;
                    }
                    if (qj == 0.0) {
                        qj = SMOOTHING;
                    }
                    sum += pj * Math.log(pj / qj);
                }
            }
            divergence[i] = sum;
            if (sum > maxSum) {
                maxSum = sum;
            }
        }
        // Guard against 0/0 (NaN) or x/0 (-Infinity) when no attribute diverges.
        if (maxSum > 0.0) {
            for (int i = 0; i < noAttributes; ++i) {
                divergence[i] /= maxSum;
            }
        }
        return divergence;
    }

    /**
     * Converts a normalized attribute value to a bin index, clamped to
     * {@code [0, bins - 1]} so rounding at the range edges cannot overflow.
     * A NaN value maps to bin 0 (Java's double-to-int conversion of NaN).
     */
    private int binIndex(double value) {
        int bin = (int) value;
        if (bin < 0) {
            return 0;
        }
        if (bin >= this.bins) {
            return this.bins - 1;
        }
        return bin;
    }

    /**
     * Returns the score of an attribute; {@link #build(Dataset)} must be called first.
     *
     * @param attribute attribute index
     * @return the maximum normalized divergence for that attribute, in [0, 1]
     */
    @Override
    public double score(int attribute) {
        return this.maxDivergence[attribute];
    }

    /**
     * Returns the number of scored attributes; {@link #build(Dataset)} must be called first.
     *
     * @return number of attributes seen during {@link #build(Dataset)}
     */
    @Override
    public int noAttributes() {
        return this.maxDivergence.length;
    }
}

