/*
 * Decompiled with CFR 0.152.
 */
package visualizer.dimensionreduction;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import visualizer.dimensionreduction.DimensionalityReduction;
import visualizer.matrix.Matrix;
import visualizer.projection.distance.Dissimilarity;

/**
 * Dimensionality reduction that groups the <em>columns</em> (dimensions) of the input matrix
 * into clusters with a k-means-style procedure and replaces every cluster by the average of
 * its member columns.
 *
 * <p>Columns are compared with the distance {@code sqrt(|2 * (1 - s)|)}, where {@code s} is the
 * dot product of an L2-normalized column with a centroid divided by the centroid's norm (a
 * cosine similarity when the column is non-zero). The normalization is applied in place to the
 * array returned by {@code Matrix.toMatrix()}, so the averages produced by {@link #execute} are
 * averages of normalized columns. NOTE(review): whether this also modifies the {@code Matrix}
 * itself depends on whether {@code toMatrix()} returns a copy, which is not visible here.
 *
 * <p>The number of output dimensions is at most {@code targetDimension}. Clusters that end up
 * empty are dropped, and the cluster count never exceeds the number of input columns.
 */
public class KMeansReduction extends DimensionalityReduction {
    /** Upper bound on the number of assignment passes made by {@link #doClustering}. */
    private static final int MAX_ITERATIONS = 30;

    /** Centroids stored column-wise: {@code centroids[row][c]} is row {@code row} of centroid {@code c}. */
    private float[][] centroids;
    /** Whether the points of the current clustering run have already been normalized. */
    private boolean isNormalized = false;

    public KMeansReduction(int targetDimension) {
        super(targetDimension);
    }

    /**
     * Reduces {@code matrix} to one output dimension per non-empty column cluster.
     *
     * @param matrix input data. The array returned by its {@code toMatrix()} is normalized in
     *               place while clustering.
     * @param diss   not used by this reduction
     * @return a {@code rows x clusters} array whose entry {@code [row][c]} is the mean of
     *         {@code row} over the (normalized) columns belonging to cluster {@code c}
     * @throws IOException propagated from {@link #doClustering}
     */
    @Override
    protected float[][] execute(Matrix matrix, Dissimilarity diss) throws IOException {
        float[][] points = matrix.toMatrix();
        ArrayList<ArrayList<Integer>> dimClusters = this.doClustering(points);
        // Java zero-initializes the new array, so no explicit fill is needed.
        float[][] reduced = new float[points.length][dimClusters.size()];
        for (int c = 0; c < dimClusters.size(); ++c) {
            ArrayList<Integer> members = dimClusters.get(c);
            for (int row = 0; row < points.length; ++row) {
                float sum = 0.0f;
                for (int dim : members) {
                    sum += points[row][dim];
                }
                reduced[row][c] = sum / (float) members.size();
            }
        }
        return reduced;
    }

    /**
     * Assigns every column of {@code points} to one of at most {@code targetDimension} clusters.
     *
     * <p>Centroids are seeded with the first {@code k} columns of {@code points}, where
     * {@code k = min(targetDimension, number of columns)}. Each pass does the following:
     * <ul>
     *   <li>It visits the columns in ascending order.</li>
     *   <li>It assigns each column to its nearest centroid (ties go to the lower index).</li>
     *   <li>It immediately folds that column into the centroid's running mean, so later
     *       assignments in the same pass see the updated centroid.</li>
     * </ul>
     * Passes repeat until one leaves every centroid unchanged, or until {@link #MAX_ITERATIONS}
     * passes have been made.
     *
     * @param points data matrix ({@code rows x columns}). It is normalized in place on the first
     *               distance computation of this run.
     * @return the non-empty clusters, each a list of column indices in ascending order. The list
     *         is empty when {@code points} has no rows or no columns.
     * @throws IllegalArgumentException if {@code targetDimension} is not positive
     * @throws IOException declared by {@link #calculateDistance}
     */
    private ArrayList<ArrayList<Integer>> doClustering(float[][] points) throws IOException {
        if (this.targetDimension < 1) {
            throw new IllegalArgumentException("targetDimension must be positive: " + this.targetDimension);
        }
        ArrayList<ArrayList<Integer>> clusters = new ArrayList<ArrayList<Integer>>();
        int numDims = points.length == 0 ? 0 : points[0].length;
        if (numDims == 0) {
            return clusters;
        }
        // Each centroid is seeded from a distinct column, so there can be no more clusters than columns.
        int k = Math.min(this.targetDimension, numDims);

        // Reset the flag so that the data of this run is normalized even if the instance was used before.
        this.isNormalized = false;
        this.centroids = new float[points.length][];
        for (int row = 0; row < points.length; ++row) {
            this.centroids[row] = Arrays.copyOf(points[row], k);
        }

        boolean centroidsChanged = true;
        int iterations = 0;
        while (centroidsChanged && iterations < MAX_ITERATIONS) {
            // Snapshot before the pass. updateCentroid mutates this.centroids in place, so
            // comparing against the live array would always report "unchanged".
            float[][] previous = deepCopy(this.centroids);
            clusters.clear();
            for (int c = 0; c < k; ++c) {
                clusters.add(new ArrayList<Integer>());
            }
            for (int dim = 0; dim < numDims; ++dim) {
                int nearest = 0;
                float best = this.calculateDistance(points, dim, 0);
                for (int c = 1; c < k; ++c) {
                    float candidate = this.calculateDistance(points, dim, c);
                    if (best > candidate) {
                        nearest = c;
                        best = candidate;
                    }
                }
                ArrayList<Integer> members = clusters.get(nearest);
                this.updateCentroid(members.size(), nearest, dim, points);
                members.add(dim);
            }
            centroidsChanged = this.isCentroidModified(previous);
            ++iterations;
        }

        // Drop clusters that received no column. Walk backwards so that the indices stay valid.
        for (int c = clusters.size() - 1; c >= 0; --c) {
            if (clusters.get(c).isEmpty()) {
                clusters.remove(c);
            }
        }
        return clusters;
    }

    /**
     * Computes the distance between column {@code dim} of {@code points} and one centroid.
     *
     * <p>If this instance has not normalized the current data yet, {@code points} is first
     * L2-normalized column by column, in place. The result is {@code sqrt(|2 * (1 - s)|)}, where
     * {@code s} is the dot product of the column with the centroid divided by the centroid's
     * norm. An all-zero centroid is treated as orthogonal to every column ({@code s = 0}),
     * instead of yielding NaN.
     *
     * @param points          data matrix; may be normalized in place
     * @param dim             column index of the point being classified
     * @param nearestCentroid index of the centroid to compare against
     * @return the distance between the column and the centroid direction
     * @throws IOException kept for signature compatibility; this implementation does not throw it
     */
    public float calculateDistance(float[][] points, int dim, int nearestCentroid) throws IOException {
        if (!this.isNormalized) {
            this.normalize(points);
        }
        float norm = 0.0f;
        float dot = 0.0f;
        for (int row = 0; row < points.length; ++row) {
            float coord = this.centroids[row][nearestCentroid];
            norm += coord * coord;
            dot += points[row][dim] * coord;
        }
        norm = (float) Math.sqrt(norm);
        // A zero-norm centroid would give 0/0 = NaN. NaN compares false with everything and
        // would silently capture or lose columns in doClustering.
        float sim = norm != 0.0f ? dot / norm : 0.0f;
        return (float) Math.sqrt(Math.abs(2.0 * (1.0 - (double) sim)));
    }

    /**
     * Folds column {@code dim} into the running mean stored for centroid {@code centroid}.
     *
     * @param nrOldPoints number of columns already averaged into this centroid during the
     *                    current pass. When it is 0, the centroid is reset to the column's values.
     * @param centroid    index of the centroid to update
     * @param dim         column of {@code points} being added
     * @param points      data matrix
     */
    private void updateCentroid(int nrOldPoints, int centroid, int dim, float[][] points) {
        for (int row = 0; row < points.length; ++row) {
            float[] centroidRow = this.centroids[row];
            centroidRow[centroid] = (centroidRow[centroid] * (float) nrOldPoints + points[row][dim])
                    / (float) (nrOldPoints + 1);
        }
    }

    /** Returns a copy of {@code source} whose rows do not share storage with the original rows. */
    private static float[][] deepCopy(float[][] source) {
        float[][] copy = new float[source.length][];
        for (int row = 0; row < source.length; ++row) {
            copy[row] = source[row].clone();
        }
        return copy;
    }

    /**
     * Reports whether any coordinate of the current centroids differs (by {@code !=}) from the
     * corresponding coordinate of {@code oldCentroids}.
     */
    private boolean isCentroidModified(float[][] oldCentroids) {
        for (int row = 0; row < oldCentroids.length; ++row) {
            float[] before = oldCentroids[row];
            float[] after = this.centroids[row];
            for (int c = 0; c < before.length; ++c) {
                if (before[c] != after[c]) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Scales every column of {@code points} to unit L2 norm, in place, and marks this instance
     * as normalized. Columns whose norm is zero are set to zeros.
     */
    private void normalize(float[][] points) {
        this.isNormalized = true;
        if (points.length == 0) {
            return;
        }
        int numDims = points[0].length;
        float[] columnNorm = new float[numDims];
        for (int dim = 0; dim < numDims; ++dim) {
            float sumSq = 0.0f;
            for (int row = 0; row < points.length; ++row) {
                sumSq += points[row][dim] * points[row][dim];
            }
            columnNorm[dim] = (float) Math.sqrt(sumSq);
        }
        for (int row = 0; row < points.length; ++row) {
            float[] values = points[row];
            for (int dim = 0; dim < numDims; ++dim) {
                values[dim] = columnNorm[dim] != 0.0f ? values[dim] / columnNorm[dim] : 0.0f;
            }
        }
    }
}

