/*
 * Decompiled with CFR 0.152.
 */
package net.sf.javaml.classification;

import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import net.sf.javaml.classification.AbstractClassifier;
import net.sf.javaml.core.Dataset;
import net.sf.javaml.core.Instance;
import net.sf.javaml.core.exception.TrainingRequiredException;
import net.sf.javaml.distance.DistanceMeasure;
import net.sf.javaml.distance.EuclideanDistance;

/**
 * k-nearest-neighbours classifier.
 * <p>
 * {@link #buildClassifier(Dataset)} only records a reference to the training data; all
 * work happens at query time, when the {@code k} training instances closest to the
 * query (as returned by {@code Dataset.kNearest} with the configured
 * {@link DistanceMeasure}) each cast one vote for their class.
 */
public class KNearestNeighbors extends AbstractClassifier {
    private static final long serialVersionUID = 1560149339188819924L;

    /** Training data recorded by {@link #buildClassifier(Dataset)}; {@code null} until then. */
    private Dataset training;

    /** Number of neighbours requested from {@code Dataset.kNearest} per query. */
    private final int k;

    /** Distance measure passed to {@code Dataset.kNearest}. */
    private final DistanceMeasure dm;

    /**
     * Creates a classifier that ranks neighbours with {@link EuclideanDistance}.
     *
     * @param k number of neighbours consulted per query
     */
    public KNearestNeighbors(int k) {
        this(k, new EuclideanDistance());
    }

    /**
     * Creates a classifier with an explicit distance measure.
     *
     * @param k  number of neighbours consulted per query
     * @param dm distance measure passed to {@code Dataset.kNearest}
     */
    public KNearestNeighbors(int k, DistanceMeasure dm) {
        this.k = k;
        this.dm = dm;
    }

    /**
     * Stores {@code data} for use by later queries; no model is fitted.
     *
     * @param data training dataset (kept by reference, not copied)
     */
    @Override
    public void buildClassifier(Dataset data) {
        this.training = data;
    }

    /**
     * Computes per-class vote scores for {@code instance} from its {@code k} nearest
     * training neighbours.
     * <p>
     * Every class in {@code training.classes()} appears in the result. Vote counts are
     * min-max rescaled to {@code [0, 1]} (the least-voted class gets 0, the most-voted 1).
     *
     * @param instance the query instance
     * @return map from class value to its rescaled vote score
     * @throws TrainingRequiredException if {@link #buildClassifier(Dataset)} was not called
     */
    @Override
    public Map<Object, Double> classDistribution(Instance instance) {
        if (this.training == null) {
            throw new TrainingRequiredException();
        }
        Set<Instance> neighbors = this.training.kNearest(this.k, instance, this.dm);
        Map<Object, Double> out = new HashMap<Object, Double>();
        for (Object classValue : this.training.classes()) {
            out.put(classValue, 0.0);
        }
        for (Instance neighbor : neighbors) {
            Object label = neighbor.classValue();
            Double votes = out.get(label);
            // A neighbour whose label is not a key of the map (presumably an unlabeled
            // instance with a null class value -- confirm against Dataset.classes()) has no
            // class to vote for. Previously this unboxed null and threw an NPE.
            if (votes != null) {
                out.put(label, votes + 1.0);
            }
        }
        normalizeMinMax(out);
        return out;
    }

    /**
     * Rescales the values of {@code counts} in place via {@code (v - min) / (max - min)}.
     * <p>
     * NOTE(review): when all values are equal (or the map is empty) they are left
     * unchanged, i.e. raw vote counts are returned rather than a normalised score. This
     * mirrors the classifier's historical output and is kept for compatibility.
     *
     * @param counts map whose (non-null) values are rescaled
     */
    private static void normalizeMinMax(Map<Object, Double> counts) {
        if (counts.isEmpty()) {
            return;
        }
        // Seed with the widest possible bounds so the true extremes are found regardless
        // of k (the old seeds min=k, max=0 misbehaved for k <= 0).
        double min = Double.POSITIVE_INFINITY;
        double max = Double.NEGATIVE_INFINITY;
        for (double value : counts.values()) {
            if (value < min) {
                min = value;
            }
            if (value > max) {
                max = value;
            }
        }
        if (max == min) {
            return;
        }
        double range = max - min;
        // setValue is not a structural modification, so updating during iteration is safe.
        for (Map.Entry<Object, Double> entry : counts.entrySet()) {
            entry.setValue((entry.getValue() - min) / range);
        }
    }
}

