/*
 * Decompiled with CFR 0.152.
 */
package net.sf.javaml.classification;

import java.util.HashMap;
import java.util.Map;
import net.sf.javaml.classification.AbstractClassifier;
import net.sf.javaml.core.Dataset;
import net.sf.javaml.core.Instance;
import net.sf.javaml.core.exception.TrainingRequiredException;
import net.sf.javaml.core.kdtree.KDTree;
import net.sf.javaml.tools.InstanceTools;

/**
 * k-nearest-neighbour classifier that answers neighbour queries with a
 * {@link KDTree} built over the attribute values of the training instances.
 */
public class KDtreeKNN
extends AbstractClassifier {
    private static final long serialVersionUID = 1560149339188819924L;
    /** Number of neighbours consulted for every query. */
    private final int k;
    /** Tree indexing every training instance by its attribute array; built by {@link #buildClassifier}. */
    private KDTree tree;
    /** Training data; {@code null} until {@link #buildClassifier} has been called. */
    private Dataset training;

    /**
     * Creates a classifier that votes over the {@code k} nearest training instances.
     *
     * @param k number of neighbours to consult per query
     */
    public KDtreeKNN(int k) {
        this.k = k;
    }

    /**
     * Stores the training data and indexes every instance in a fresh KD-tree
     * keyed by the instance's attribute values.
     *
     * @param data training instances; every one of them is inserted into the tree
     */
    @Override
    public void buildClassifier(Dataset data) {
        this.training = data;
        this.tree = new KDTree(data.noAttributes());
        for (Instance inst : data) {
            this.tree.insert(InstanceTools.array(inst), inst);
        }
    }

    /**
     * Counts how many of the nearest training instances belong to each class
     * and rescales those counts with min-max normalisation.
     *
     * <p>After rescaling, the most-voted class maps to {@code 1.0} and the
     * least-voted class maps to {@code 0.0}. The values are NOT probabilities
     * and need not sum to one. When every class received the same number of
     * votes, the raw counts are returned unchanged.
     *
     * @param instance the instance to classify
     * @return a map from every training class (plus any neighbour class not
     *         listed by {@code classes()}) to its rescaled vote
     * @throws TrainingRequiredException if {@link #buildClassifier} has not been called
     */
    @Override
    public Map<Object, Double> classDistribution(Instance instance) {
        if (this.training == null) {
            throw new TrainingRequiredException();
        }
        // Never ask the tree for more neighbours than it holds; with a small
        // training set we simply vote over all of it.
        int neighbourCount = Math.min(this.k, this.training.size());
        Object[] neighbors = this.tree.nearest(InstanceTools.array(instance), neighbourCount);

        Map<Object, Double> out = new HashMap<Object, Double>();
        for (Object cls : this.training.classes()) {
            out.put(cls, 0.0);
        }
        for (Object o : neighbors) {
            Object cls = ((Instance) o).classValue();
            Double votes = out.get(cls);
            // A neighbour class missing from classes() starts from zero
            // instead of failing with a NullPointerException on unboxing.
            out.put(cls, votes == null ? 1.0 : votes + 1.0);
        }
        rescaleToUnitRange(out);
        return out;
    }

    /**
     * Linearly maps the values of {@code dist} onto {@code [0, 1]} in place.
     * Leaves the map untouched when it is empty or all its values are equal.
     */
    private static void rescaleToUnitRange(Map<Object, Double> dist) {
        double min = Double.POSITIVE_INFINITY;
        double max = Double.NEGATIVE_INFINITY;
        for (double value : dist.values()) {
            if (value > max) {
                max = value;
            }
            if (value < min) {
                min = value;
            }
        }
        if (max > min) {
            double range = max - min;
            // setValue is not a structural modification, so updating while iterating is safe.
            for (Map.Entry<Object, Double> entry : dist.entrySet()) {
                entry.setValue((entry.getValue() - min) / range);
            }
        }
    }
}

