/*
 * Decompiled with CFR 0.152.
 */
package com.macrofocus.high_d.mds.tsne;

import com.macrofocus.high_d.mds.AbstractMDSEngine;
import com.macrofocus.high_d.mds.MDSEngineEvent;
import com.macrofocus.high_d.mds.MDSModel;
import com.macrofocus.high_d.mds.MutableMDSModel;
import com.macrofocus.high_d.mds.tsne.MatrixOperations;
import com.macrofocus.high_d.mds.tsne.MatrixOperationsFactory;
import com.macrofocus.molap.dataframe.matrix.Matrix;
import com.macrofocus.timer.CPExecutor;
import java.util.Arrays;

/**
 * MDS engine that lays out the rows of a {@link Matrix} in two dimensions using t-distributed
 * Stochastic Neighbor Embedding (t-SNE).
 *
 * <p>The optimisation is incremental: every call to {@code execute()} on the command returned by
 * {@link #createRunCommand()} performs one gradient-descent iteration, writes the current
 * coordinates into the {@link MutableMDSModel} and fires an iteration event. All numerical
 * kernels are delegated to a {@link MatrixOperations} created by this engine's
 * {@link MatrixOperationsFactory}.
 */
public class TSNEMDSEngine
extends AbstractMDSEngine {
    /** Number of embedding dimensions; the model stores an X and a Y coordinate per row. */
    private static final int NO_DIMS = 2;
    /** Target perplexity for calibrating the per-point Gaussian precisions. */
    private static final double PERPLEXITY = 20.0;
    /** Number of gradient-descent iterations to perform. */
    private static final int MAX_ITER = 2000;
    /** Tolerance on |H - log(perplexity)| in the per-point binary search. */
    private static final double PERPLEXITY_TOLERANCE = 1.0E-5;
    /** Maximum number of binary-search steps per point. */
    private static final int MAX_BETA_TRIES = 50;
    /** Lower bound applied to every entry of the P and Q affinity matrices. */
    private static final double MIN_PROBABILITY = 1.0E-12;
    /** Factor P is multiplied by at start-up and divided by again at {@link #STOP_EXAGGERATION_ITER}. */
    private static final double EARLY_EXAGGERATION = 4.0;
    /** Iteration at which the early exaggeration of P is removed. */
    private static final int STOP_EXAGGERATION_ITER = 100;
    /** Iterations before this index use {@link #INITIAL_MOMENTUM}, later ones {@link #FINAL_MOMENTUM}. */
    private static final int MOMENTUM_SWITCH_ITER = 20;
    private static final double INITIAL_MOMENTUM = 0.5;
    private static final double FINAL_MOMENTUM = 0.8;
    /** Gains are clamped from below to this value. */
    private static final double MIN_GAIN = 0.01;
    /** Learning rate. */
    private static final double ETA = 500.0;
    /** Substitute for a zero kernel sum in {@link #Hbeta} so the entropy stays finite. */
    private static final double MIN_SUM_P = 1.0E-8;

    /**
     * Matrix kernels, created lazily on the first {@code execute()} from this engine's own factory.
     * Per-instance (it used to be static, which made every engine reuse the operations produced by
     * whichever engine happened to run first).
     */
    private MatrixOperations mo;
    private final MutableMDSModel mapModel;
    private final Matrix matrixModel;
    private final MatrixOperationsFactory matrixOperationsFactory;

    public TSNEMDSEngine(MutableMDSModel mapModel, Matrix matrixModel, MatrixOperationsFactory matrixOperationsFactory, CPExecutor executor) {
        super(executor);
        this.mapModel = mapModel;
        this.matrixModel = matrixModel;
        this.matrixOperationsFactory = matrixOperationsFactory;
    }

    @Override
    public MDSModel getModel() {
        return this.mapModel;
    }

    @Override
    protected Matrix getDistanceTable() {
        throw new UnsupportedOperationException();
    }

    /**
     * Creates a command that runs the t-SNE optimisation one iteration per {@code execute()} call.
     * {@code execute()} returns {@code true} while more iterations remain and {@code false} once
     * {@link #MAX_ITER} iterations have run or the engine reports it was interrupted. (The executor
     * presumably keeps calling it while it returns {@code true} — confirm against CPExecutor.)
     */
    @Override
    public CPExecutor.Command createRunCommand() {
        return new CPExecutor.Command(){
            /** Index of the current iteration; -1 until the first execute() call. */
            int progress = -1;
            double[][] values;
            int n;
            /** Current embedding, n x NO_DIMS. */
            double[][] y;
            /** Gradient of the current iteration. */
            double[][] dY;
            /** Previous update step (momentum term). */
            double[][] iY;
            /** Per-coordinate adaptive learning-rate gains. */
            double[][] gains;
            /** Symmetrised, normalised joint probabilities of the input space. */
            double[][] p;
            double momentum = INITIAL_MOMENTUM;

            public boolean execute() {
                if (this.progress < 0) {
                    if (mo == null) {
                        mo = TSNEMDSEngine.this.matrixOperationsFactory.createMatrixOperations();
                    }
                    this.values = TSNEMDSEngine.this.readValues();
                }
                return this.tsne(this.values, NO_DIMS, PERPLEXITY, MAX_ITER);
            }

            /**
             * Performs a single optimisation step (initialising state on the first call).
             *
             * @return {@code true} if another step should follow, {@code false} when finished
             */
            private boolean tsne(double[][] x, int noDims, double perplexity, int maxIter) {
                if (this.progress < 0) {
                    this.initialize(x, noDims, perplexity);
                }
                // Exactly one increment per step: progress runs 0, 1, ..., maxIter - 1, so the
                // momentum switch and the exaggeration stop happen at the intended iterations.
                ++this.progress;
                if (this.progress < maxIter && !TSNEMDSEngine.this.isInterrupted()) {
                    this.iterate();
                    TSNEMDSEngine.this.publishCoordinates(this.y);
                    TSNEMDSEngine.this.notifyEngineIterated(new MDSEngineEvent(TSNEMDSEngine.this, this.progress));
                    return true;
                }
                TSNEMDSEngine.this.publishCoordinates(this.y);
                TSNEMDSEngine.this.notifyEngineIterated(new MDSEngineEvent(TSNEMDSEngine.this, maxIter));
                TSNEMDSEngine.this.notifyEngineFinished(new MDSEngineEvent(TSNEMDSEngine.this, maxIter));
                return false;
            }

            /** Sets up the random initial embedding, optimiser state and the exaggerated P matrix. */
            private void initialize(double[][] x, int noDims, double perplexity) {
                this.n = x.length;
                this.y = mo.rnorm(this.n, noDims);
                this.dY = mo.fillMatrix(this.n, noDims, 0.0);
                this.iY = mo.fillMatrix(this.n, noDims, 0.0);
                this.gains = mo.fillMatrix(this.n, noDims, 1.0);
                this.p = TSNEMDSEngine.this.x2p(x, PERPLEXITY_TOLERANCE, perplexity).p;
                // Symmetrise the conditional probabilities and normalise them to sum to 1.
                this.p = mo.plus(this.p, mo.transpose(this.p));
                this.p = mo.scalarDivide(this.p, mo.sum(this.p));
                this.p = mo.scalarMult(this.p, EARLY_EXAGGERATION);
                this.p = mo.maximum(this.p, MIN_PROBABILITY);
            }

            /** One gradient-descent update of {@link #y} with momentum and adaptive gains. */
            private void iterate() {
                // Student-t kernel: num_ij = 1 / (1 + |y_i - y_j|^2), with a zeroed diagonal.
                double[][] sumY = mo.transpose(mo.sum(mo.square(this.y), 1));
                double[][] squaredDistances = mo.addRowVector(mo.transpose(mo.addRowVector(mo.scalarMult(mo.times(this.y, mo.transpose(this.y)), -2.0), sumY)), sumY);
                double[][] num = mo.scalarInverse(mo.scalarPlus(squaredDistances, 1.0));
                mo.assignAtIndex(num, mo.range(this.n), mo.range(this.n), 0.0);
                double[][] q = mo.maximum(mo.scalarDivide(num, mo.sum(num)), MIN_PROBABILITY);
                // Row i of (diag(rowsum(L)) - L) * Y equals sum_j L_ij (y_i - y_j); the 4 is the
                // constant factor of the t-SNE gradient.
                double[][] l = mo.scalarMultiply(mo.minus(this.p, q), num);
                this.dY = mo.scalarMult(mo.times(mo.minus(mo.diag(mo.sum(l, 1)), l), this.y), 4.0);
                this.momentum = this.progress < MOMENTUM_SWITCH_ITER ? INITIAL_MOMENTUM : FINAL_MOMENTUM;
                // Increase the gain where gradient and previous step disagree in sign, else decay it.
                double[][] signsDiffer = mo.abs(mo.negate(mo.equal(mo.biggerThan(this.dY, 0.0), mo.biggerThan(this.iY, 0.0))));
                double[][] signsAgree = mo.abs(mo.equal(mo.biggerThan(this.dY, 0.0), mo.biggerThan(this.iY, 0.0)));
                this.gains = mo.plus(mo.scalarMultiply(mo.scalarPlus(this.gains, 0.2), signsDiffer), mo.scalarMultiply(mo.scalarMult(this.gains, 0.8), signsAgree));
                mo.assignAllLessThan(this.gains, MIN_GAIN, MIN_GAIN);
                this.iY = mo.minus(mo.scalarMult(this.iY, this.momentum), mo.scalarMult(mo.scalarMultiply(this.gains, this.dY), ETA));
                this.y = mo.plus(this.y, this.iY);
                // Re-centre the embedding on the origin.
                this.y = mo.minus(this.y, mo.tile(mo.mean(this.y, 0), this.n, 1));
                if (this.progress == STOP_EXAGGERATION_ITER) {
                    this.p = mo.scalarDivide(this.p, EARLY_EXAGGERATION);
                }
            }
        };
    }

    /** Copies the matrix model into a dense row-major {@code double[rows][columns]} array. */
    private double[][] readValues() {
        int rowCount = this.matrixModel.getRowCount();
        int columnCount = this.matrixModel.getColumnCount();
        double[][] values = new double[rowCount][columnCount];
        for (int row = 0; row < rowCount; ++row) {
            for (int column = 0; column < columnCount; ++column) {
                values[row][column] = this.matrixModel.getDouble(this.matrixModel.getRowKey(row), this.matrixModel.getColumnKey(column));
            }
        }
        return values;
    }

    /** Writes the first two coordinates of every row of {@code coordinates} into the map model. */
    private void publishCoordinates(double[][] coordinates) {
        for (int i = 0; i < coordinates.length; ++i) {
            double[] point = coordinates[i];
            this.mapModel.setX(i, point[0]);
            this.mapModel.setY(i, point[1]);
        }
    }

    /**
     * Computes conditional input-space probabilities with a per-point Gaussian precision.
     *
     * <p>For every row i, a binary search adjusts the precision beta_i until the entropy of row i
     * differs from log(perplexity) by at most {@code tol}, or {@link #MAX_BETA_TRIES} steps have
     * been taken.
     *
     * @param x input points, one per row
     * @param tol tolerance on the entropy difference
     * @param perplexity target perplexity
     * @return an {@link R} whose {@code p} is the n x n matrix of row-wise probabilities (zero
     *     diagonal) and whose {@code beta} holds the chosen precisions
     */
    private R x2p(double[][] x, double tol, double perplexity) {
        int n = x.length;
        // Squared Euclidean distances: |x_i|^2 + |x_j|^2 - 2 x_i . x_j
        double[][] sumX = mo.sum(mo.square(x), 1);
        double[][] times = mo.scalarMult(mo.times(x, mo.transpose(x)), -2.0);
        double[][] prodSum = mo.addColumnVector(mo.transpose(times), sumX);
        double[][] d = mo.addRowVector(prodSum, mo.transpose(sumX));
        double[][] p = mo.fillMatrix(n, n, 0.0);
        // A plain vector of ones; no need to allocate an n x n matrix for it.
        double[] beta = new double[n];
        Arrays.fill(beta, 1.0);
        double logU = Math.log(perplexity);
        for (int i = 0; i < n; ++i) {
            double betamin = Double.NEGATIVE_INFINITY;
            double betamax = Double.POSITIVE_INFINITY;
            // Distances from point i to every other point (self excluded).
            double[][] di = mo.getValuesFromRow(d, i, mo.concatenate(mo.range(0, i), mo.range(i + 1, n)));
            R hbeta = this.Hbeta(di, beta[i]);
            double[][] thisP = hbeta.p;
            double hDiff = hbeta.h - logU;
            for (int tries = 0; Math.abs(hDiff) > tol && tries < MAX_BETA_TRIES; ++tries) {
                if (hDiff > 0.0) {
                    // Entropy too high: increase precision.
                    betamin = beta[i];
                    beta[i] = Double.isInfinite(betamax) ? beta[i] * 2.0 : (beta[i] + betamax) / 2.0;
                } else {
                    // Entropy too low: decrease precision.
                    betamax = beta[i];
                    beta[i] = Double.isInfinite(betamin) ? beta[i] / 2.0 : (beta[i] + betamin) / 2.0;
                }
                hbeta = this.Hbeta(di, beta[i]);
                thisP = hbeta.p;
                hDiff = hbeta.h - logU;
            }
            mo.assignValuesToRow(p, i, mo.concatenate(mo.range(0, i), mo.range(i + 1, n)), thisP[0]);
        }
        R r = new R();
        r.p = p;
        r.beta = beta;
        return r;
    }

    /**
     * Evaluates the Gaussian kernel exp(-beta * d) for one row of squared distances.
     *
     * @param d squared distances from one point to the others
     * @param beta kernel precision
     * @return an {@link R} with {@code p} the normalised kernel values and {@code h} their
     *     entropy (natural log)
     */
    private R Hbeta(double[][] d, double beta) {
        double[][] P = mo.exp(mo.scalarMult(d, -beta));
        double sumP = mo.sum(P);
        if (sumP == 0.0) {
            // Every kernel value underflowed. Without this guard H would be NaN (log 0 + 0/0): the
            // caller's search would stop at once and the row of P would be NaN. A tiny positive sum
            // yields a very low entropy instead, so the search lowers beta.
            sumP = MIN_SUM_P;
        }
        double H = Math.log(sumP) + beta * mo.sum(mo.scalarMultiply(d, P)) / sumP;
        P = mo.scalarDivide(P, sumP);
        R r = new R();
        r.h = H;
        r.p = P;
        return r;
    }

    /** Result holder shared by {@link #x2p} and {@link #Hbeta}. */
    static class R {
        double[][] p;
        double[] beta;
        double h;

        R() {
        }
    }
}

