/*
 * Decompiled with CFR 0.152.
 */
package net.sf.javaml.classification;

import java.util.HashMap;
import java.util.Map;
import java.util.Vector;
import net.sf.javaml.classification.AbstractClassifier;
import net.sf.javaml.clustering.SOM;
import net.sf.javaml.core.Dataset;
import net.sf.javaml.core.Instance;
import net.sf.javaml.distance.EuclideanDistance;
import net.sf.javaml.tools.DatasetTools;

public class SOM
extends AbstractClassifier {
    private static final long serialVersionUID = 6369609967132433683L;
    private net.sf.javaml.clustering.SOM som;
    private Instance[] centroids;
    private Vector<Map<Object, Double>> distribution;

    public SOM(int xdim, int ydim, SOM.GridType grid, int iterations, double learningRate, int initialRadius, SOM.LearningType learning, SOM.NeighbourhoodFunction nbf) {
        this.som = new net.sf.javaml.clustering.SOM(xdim, ydim, grid, iterations, learningRate, initialRadius, learning, nbf);
    }

    @Override
    public void buildClassifier(Dataset data) {
        Dataset[] clusters = this.som.cluster(data);
        this.centroids = new Instance[clusters.length];
        this.distribution = new Vector();
        for (int i = 0; i < clusters.length; ++i) {
            this.centroids[i] = DatasetTools.average(clusters[i]);
            this.distribution.add(this.distribution(data, clusters[i]));
        }
    }

    private Map<Object, Double> distribution(Dataset original, Dataset dataset) {
        HashMap<Object, Double> out = new HashMap<Object, Double>();
        for (Object e : original.classes()) {
            out.put(e, 0.0);
        }
        for (Instance instance : dataset) {
            out.put(instance.classValue(), (Double)out.get(instance.classValue()) + 1.0 / (double)dataset.size());
        }
        return out;
    }

    @Override
    public Map<Object, Double> classDistribution(Instance inst) {
        double min = Double.POSITIVE_INFINITY;
        EuclideanDistance ed = new EuclideanDistance();
        int index = 0;
        for (int i = 0; i < this.centroids.length; ++i) {
            double d = ed.measure(this.centroids[i], inst);
            if (!(d < min)) continue;
            d = min;
            index = i;
        }
        return this.distribution.get(index);
    }
}

