/*
 * Decompiled with CFR 0.152.
 */
package jalview.analysis;

import jalview.bin.Console;
import jalview.math.Matrix;
import jalview.math.MatrixI;
import jalview.math.MiscMath;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Map;
import java.util.TreeMap;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.SingularValueDecomposition;

public class ccAnalysis {
    private byte dim = 0;
    private MatrixI scoresOld;

    public ccAnalysis(MatrixI scores, byte dim) {
        for (int i = 0; i < scores.height(); ++i) {
            for (int j = 0; j < scores.width(); ++j) {
                if (Double.isNaN(scores.getValue(i, j))) continue;
                scores.setValue(i, j, (double)Math.round(scores.getValue(i, j) * 10000.0) / 10000.0);
            }
        }
        this.scoresOld = scores;
        this.dim = dim;
    }

    private int[] initialiseDistrusts(byte[] hSigns, MatrixI scores) {
        int[] distrustScores = new int[scores.width()];
        for (int i = 0; i < scores.width(); ++i) {
            byte hASign = hSigns[i];
            int conHypNum = 0;
            int proHypNum = 0;
            for (int j = 0; j < scores.width(); ++j) {
                double cell = scores.getRow(i)[j];
                byte hBSign = hSigns[j];
                if (Double.isNaN(cell)) continue;
                byte cellSign = (byte)Math.signum(cell);
                if (cellSign == hASign * hBSign) {
                    ++proHypNum;
                    continue;
                }
                ++conHypNum;
            }
            distrustScores[i] = conHypNum - proHypNum;
        }
        return distrustScores;
    }

    private byte[] optimiseHypothesis(byte[] hSigns, int[] distrustScores, MatrixI scores) {
        int[] maxes = MiscMath.findMax(distrustScores);
        int maxDistrustIndex = maxes[0];
        int maxDistrust = maxes[1];
        if (maxDistrust > 0) {
            int n = maxDistrustIndex;
            hSigns[n] = (byte)(hSigns[n] * -1);
            int n2 = maxDistrustIndex;
            distrustScores[n2] = distrustScores[n2] * -1;
            byte hASign = hSigns[maxDistrustIndex];
            for (int NOTmaxDistrustIndex = 0; NOTmaxDistrustIndex < distrustScores.length; ++NOTmaxDistrustIndex) {
                if (NOTmaxDistrustIndex == maxDistrustIndex) continue;
                byte hBSign = hSigns[NOTmaxDistrustIndex];
                double cell = scores.getValue(maxDistrustIndex, NOTmaxDistrustIndex);
                if (Double.isNaN(cell)) continue;
                byte cellSign = (byte)Math.signum(cell);
                if (cellSign == hASign * hBSign) {
                    int n3 = NOTmaxDistrustIndex;
                    distrustScores[n3] = distrustScores[n3] - 2;
                    continue;
                }
                int n4 = NOTmaxDistrustIndex;
                distrustScores[n4] = distrustScores[n4] + 2;
            }
            return this.optimiseHypothesis(hSigns, distrustScores, scores);
        }
        return hSigns;
    }

    public MatrixI run() throws Exception {
        MatrixI eigenMatrix = this.scoresOld.copy();
        MatrixI repMatrix = this.scoresOld.copy();
        try {
            System.out.println("Input correlation matrix:");
            eigenMatrix.print(System.out, "%1.4f ");
            int matrixWidth = eigenMatrix.width();
            int matrixElementsTotal = (int)Math.pow(matrixWidth, 2.0);
            float correctionFactor = (float)(matrixElementsTotal - eigenMatrix.countNaN()) / (float)matrixElementsTotal;
            byte[] hSigns = new byte[matrixWidth];
            Arrays.fill(hSigns, (byte)1);
            hSigns = this.optimiseHypothesis(hSigns, this.initialiseDistrusts(hSigns, eigenMatrix), eigenMatrix);
            double[] hAbs = MiscMath.sqrt(eigenMatrix.absolute().meanRow());
            double[] hValues = MiscMath.elementwiseMultiply(hSigns, hAbs);
            ArrayList<int[]> estimatedPositions = new ArrayList<int[]>();
            for (int rowIndex = 0; rowIndex < matrixWidth - 1; ++rowIndex) {
                for (int columnIndex = rowIndex + 1; columnIndex < matrixWidth; ++columnIndex) {
                    double cell = eigenMatrix.getValue(rowIndex, columnIndex);
                    if (!Double.isNaN(cell)) continue;
                    cell = hValues[rowIndex] * hValues[columnIndex];
                    eigenMatrix.setValue(rowIndex, columnIndex, cell);
                    eigenMatrix.setValue(columnIndex, rowIndex, cell);
                    estimatedPositions.add(new int[]{rowIndex, columnIndex});
                }
            }
            int diagonalIndex = 0;
            while (diagonalIndex < matrixWidth) {
                double cell = Math.pow(hValues[diagonalIndex], 2.0);
                eigenMatrix.setValue(diagonalIndex, diagonalIndex, cell);
                estimatedPositions.add(new int[]{diagonalIndex, diagonalIndex++});
            }
            System.out.print("initial values: [ ");
            for (double h : hValues) {
                System.out.print(String.format("%1.4f, ", h));
            }
            System.out.println(" ]");
            double[] hValuesOld = new double[matrixWidth];
            int iterationCount = 0;
            while (true) {
                for (int hIndex = 0; hIndex < matrixWidth; ++hIndex) {
                    double newH;
                    hValues[hIndex] = newH = Arrays.stream(MiscMath.elementwiseMultiply(hValues, eigenMatrix.getRow(hIndex))).sum() / Arrays.stream(MiscMath.elementwiseMultiply(hValues, hValues)).sum();
                }
                System.out.print(String.format("iteration %d: [ ", iterationCount));
                for (double h : hValues) {
                    System.out.print(String.format("%1.4f, ", h));
                }
                System.out.println(" ]");
                Object hIndex = estimatedPositions.iterator();
                while (hIndex.hasNext()) {
                    int[] pair = (int[])hIndex.next();
                    double newVal = hValues[pair[0]] * hValues[pair[1]];
                    eigenMatrix.setValue(pair[0], pair[1], newVal);
                    eigenMatrix.setValue(pair[1], pair[0], newVal);
                }
                ++iterationCount;
                if (MiscMath.allClose(hValues, hValuesOld, 0.0, 1.0E-5, false)) break;
                System.arraycopy(hValues, 0, hValuesOld, 0, hValues.length);
            }
            eigenMatrix.tred();
            eigenMatrix.tqli();
            System.out.println("eigenmatrix");
            eigenMatrix.print(System.out, "%8.2f");
            System.out.println();
            System.out.println("uncorrected eigenvalues");
            eigenMatrix.printD(System.out, "%2.4f ");
            System.out.println();
            double[] eigenVals = eigenMatrix.getD();
            TreeMap eigenPairs = new TreeMap(Comparator.reverseOrder());
            for (int i = 0; i < eigenVals.length; ++i) {
                eigenPairs.put(eigenVals[i], i);
            }
            double[][] _repMatrix = new double[eigenVals.length][this.dim];
            double[][] _oldMatrix = new double[eigenVals.length][this.dim];
            double[] correctedEigenValues = new double[this.dim];
            byte l = 0;
            for (Map.Entry pair : eigenPairs.entrySet()) {
                double eigenValue = (Double)pair.getKey();
                int column = (Integer)pair.getValue();
                double[] eigenVector = eigenMatrix.getColumn(column);
                if (l >= 1) {
                    eigenValue /= (double)correctionFactor;
                }
                correctedEigenValues[l] = eigenValue;
                for (int j = 0; j < eigenVector.length; ++j) {
                    _repMatrix[j][l] = eigenValue < 0.0 ? 0.0 : -Math.sqrt(eigenValue) * eigenVector[j];
                    double tmpOldScore = this.scoresOld.getColumn(column)[j];
                    _oldMatrix[j][this.dim - l - 1] = Double.isNaN(tmpOldScore) ? 0.0 : tmpOldScore;
                }
                if (++l < this.dim) continue;
                break;
            }
            System.out.println("correctedEigenValues");
            MiscMath.print(correctedEigenValues, "%2.4f ");
            repMatrix = new Matrix(_repMatrix);
            repMatrix.setD(correctedEigenValues);
            Matrix oldMatrix = new Matrix(_oldMatrix);
            MatrixI dotMatrix = repMatrix.postMultiply(repMatrix.transpose());
            double rmsd = this.scoresOld.rmsd(dotMatrix);
            System.out.println("iteration, rmsd, maxDiff, rmsdDiff");
            System.out.println(String.format("0, %8.5f, -, -", rmsd));
            for (int iteration = 1; iteration < 21; ++iteration) {
                int j;
                MatrixI repMatrixOLD = repMatrix.copy();
                MatrixI dotMatrixOLD = dotMatrix.copy();
                for (int hAIndex = 0; hAIndex < oldMatrix.height(); ++hAIndex) {
                    double[] row = oldMatrix.getRow(hAIndex);
                    double[] hA = repMatrix.getRow(hAIndex);
                    double[] hAlsm = this.leastSquaresOptimisation(repMatrix, this.scoresOld, hAIndex);
                    for (j = 0; j < repMatrix.width(); ++j) {
                        repMatrix.setValue(hAIndex, j, hAlsm[j]);
                    }
                }
                dotMatrix = repMatrix.postMultiply(repMatrix.transpose());
                rmsd = this.scoresOld.rmsd(dotMatrix);
                MatrixI diff = repMatrix.subtract(repMatrixOLD).absolute();
                double maxDiff = 0.0;
                for (int i = 0; i < diff.height(); ++i) {
                    for (j = 0; j < diff.width(); ++j) {
                        maxDiff = diff.getValue(i, j) > maxDiff ? diff.getValue(i, j) : maxDiff;
                    }
                }
                double rmsdDiff = dotMatrix.rmsd(dotMatrixOLD);
                System.out.println(String.format("%d, %8.5f, %8.5f, %8.5f", iteration, rmsd, maxDiff, rmsdDiff));
                if (Math.abs(maxDiff) > 1.0E-6) continue;
                repMatrix = repMatrixOLD.copy();
                break;
            }
        }
        catch (Exception q) {
            Console.error("Error computing cc_analysis:  " + q.getMessage());
            q.printStackTrace();
        }
        System.out.println("final coordinates:");
        repMatrix.print(System.out, "%1.8f ");
        return repMatrix;
    }

    private double[] originalToEquasionSystem(double[] hA, MatrixI repMatrix, MatrixI scoresOld, int hAIndex) {
        double[] originalRow = scoresOld.getRow(hAIndex);
        int nans = MiscMath.countNaN(originalRow);
        double[] result = new double[originalRow.length - nans];
        int resultIndex = 0;
        for (int hBIndex = 0; hBIndex < originalRow.length; ++hBIndex) {
            double pairwiseCC = originalRow[hBIndex];
            if (Double.isNaN(pairwiseCC)) continue;
            double[] hB = repMatrix.getRow(hBIndex);
            result[resultIndex++] = MiscMath.sum(MiscMath.elementwiseMultiply(hA, hB)) - pairwiseCC;
        }
        return result;
    }

    private MatrixI approximateDerivative(MatrixI repMatrix, MatrixI scoresOld, int hAIndex) {
        double[] hA = repMatrix.getRow(hAIndex);
        double[] f0 = this.originalToEquasionSystem(hA, repMatrix, scoresOld, hAIndex);
        double[] signX0 = new double[hA.length];
        double[] xAbs = new double[hA.length];
        for (int i = 0; i < hA.length; ++i) {
            signX0[i] = hA[i] >= 0.0 ? 1.0 : -1.0;
            xAbs[i] = Math.abs(hA[i]) >= 1.0 ? Math.abs(hA[i]) : 1.0;
        }
        double rstep = Math.pow(Math.ulp(1.0), 0.5);
        double[] h = new double[hA.length];
        for (int i = 0; i < hA.length; ++i) {
            h[i] = rstep * signX0[i] * xAbs[i];
        }
        int m = f0.length;
        int n = hA.length;
        double[][] jTransposed = new double[n][m];
        for (int i = 0; i < h.length; ++i) {
            double[] x = new double[h.length];
            System.arraycopy(hA, 0, x, 0, h.length);
            int n2 = i;
            x[n2] = x[n2] + h[i];
            double dx = x[i] - hA[i];
            double[] df = this.originalToEquasionSystem(x, repMatrix, scoresOld, hAIndex);
            for (int j = 0; j < df.length; ++j) {
                int n3 = j;
                df[n3] = df[n3] - f0[j];
                jTransposed[i][j] = df[j] / dx;
            }
        }
        MatrixI J = new Matrix(jTransposed).transpose();
        return J;
    }

    private double[] phiAndDerivative(double alpha, double[] suf, double[] s, double Delta) {
        double[] denom = MiscMath.elementwiseAdd(MiscMath.elementwiseMultiply(s, s), alpha);
        double pNorm = MiscMath.norm(MiscMath.elementwiseDivide(suf, denom));
        double phi = pNorm - Delta;
        double phiPrime = -MiscMath.sum(MiscMath.elementwiseDivide(MiscMath.elementwiseMultiply(suf, suf), MiscMath.elementwiseMultiply(MiscMath.elementwiseMultiply(denom, denom), denom))) / pNorm;
        return new double[]{phi, phiPrime};
    }

    private TrustRegion solveLsqTrustRegion(int n, int m, double[] uf, double[] s, MatrixI V, double Delta, double alpha) {
        int iteration;
        double[] p;
        double[] suf = MiscMath.elementwiseMultiply(s, uf);
        boolean fullRank = false;
        if (m >= n) {
            double threshold = s[0] * Math.ulp(1.0) * (double)m;
            boolean bl = fullRank = s[s.length - 1] > threshold;
        }
        if (fullRank && MiscMath.norm(p = MiscMath.elementwiseMultiply(V.sumProduct(MiscMath.elementwiseDivide(uf, s)), -1.0)) <= Delta) {
            TrustRegion result = new TrustRegion(p, 0.0, 0);
            return result;
        }
        double alphaUpper = MiscMath.norm(suf) / Delta;
        double alphaLower = 0.0;
        if (fullRank) {
            double[] phiAndPrime = this.phiAndDerivative(0.0, suf, s, Delta);
            alphaLower = -phiAndPrime[0] / phiAndPrime[1];
        }
        alpha = !fullRank && alpha == 0.0 ? (alpha = Math.max(0.001 * alphaUpper, Math.pow(alphaLower * alphaUpper, 0.5))) : alpha;
        for (iteration = 0; iteration < 10; ++iteration) {
            alpha = alpha < alphaLower || alpha > alphaUpper ? (alpha = Math.max(0.001 * alphaUpper, Math.pow(alphaLower * alphaUpper, 0.5))) : alpha;
            double[] phiAndPrime = this.phiAndDerivative(alpha, suf, s, Delta);
            double phi = phiAndPrime[0];
            double phiPrime = phiAndPrime[1];
            alphaUpper = phi < 0.0 ? alpha : alphaUpper;
            double ratio = phi / phiPrime;
            alphaLower = Math.max(alphaLower, alpha - ratio);
            alpha -= (phi + Delta) * ratio / Delta;
            if (Math.abs(phi) < 0.01 * Delta) break;
        }
        double[] tmp = MiscMath.elementwiseDivide(suf, MiscMath.elementwiseAdd(MiscMath.elementwiseMultiply(s, s), alpha));
        double[] p2 = MiscMath.elementwiseMultiply(V.sumProduct(tmp), -1.0);
        p2 = MiscMath.elementwiseMultiply(p2, Delta / MiscMath.norm(p2));
        TrustRegion result = new TrustRegion(p2, alpha, iteration + 1);
        return result;
    }

    private double evaluateQuadratic(MatrixI J, double[] g, double[] s) {
        double[] Js = J.sumProduct(s);
        double q = MiscMath.dot(Js, Js);
        double l = MiscMath.dot(s, g);
        return 0.5 * q + l;
    }

    private double[] updateTrustRegionRadius(double Delta, double actualReduction, double predictedReduction, double stepNorm, boolean boundHit) {
        double ratio = 0.0;
        ratio = predictedReduction > 0.0 ? actualReduction / predictedReduction : (predictedReduction == 0.0 && actualReduction == 0.0 ? 1.0 : 0.0);
        if (ratio < 0.25) {
            Delta = 0.25 * stepNorm;
        } else if (ratio > 0.75 && boundHit) {
            Delta *= 2.0;
        }
        return new double[]{Delta, ratio};
    }

    private double[] trf(MatrixI repMatrix, MatrixI scoresOld, int index, MatrixI J) {
        double[] hA = repMatrix.getRow(index);
        double[] f0 = this.originalToEquasionSystem(hA, repMatrix, scoresOld, index);
        int nfev = 1;
        int m = J.height();
        int n = J.width();
        double cost = 0.5 * MiscMath.dot(f0, f0);
        double[] g = J.transpose().sumProduct(f0);
        double Delta = MiscMath.norm(hA);
        int maxNfev = hA.length * 100;
        double alpha = 0.0;
        double gNorm = 0.0;
        boolean terminationStatus = false;
        int iteration = 0;
        while (true) {
            gNorm = MiscMath.norm(g);
            if (terminationStatus || nfev == maxNfev) break;
            SingularValueDecomposition svd = new SingularValueDecomposition((RealMatrix)new Array2DRowRealMatrix(J.asArray()));
            Matrix U = new Matrix(svd.getU().getData());
            double[] s = svd.getSingularValues();
            MatrixI V = new Matrix(svd.getV().getData()).transpose();
            double[] uf = U.transpose().sumProduct(f0);
            double actualReduction = -1.0;
            double[] xNew = new double[hA.length];
            double[] fNew = new double[f0.length];
            double costNew = 0.0;
            double stepHnorm = 0.0;
            while (actualReduction <= 0.0 && nfev < maxNfev) {
                TrustRegion trustRegion = this.solveLsqTrustRegion(n, m, uf, s, V, Delta, alpha);
                double[] stepH = trustRegion.getStep();
                alpha = trustRegion.getAlpha();
                int nIterations = trustRegion.getIteration();
                double predictedReduction = -this.evaluateQuadratic(J, g, stepH);
                xNew = MiscMath.elementwiseAdd(hA, stepH);
                fNew = this.originalToEquasionSystem(xNew, repMatrix, scoresOld, index);
                ++nfev;
                stepHnorm = MiscMath.norm(stepH);
                if (MiscMath.countNaN(fNew) > 0) {
                    Delta = 0.25 * stepHnorm;
                    continue;
                }
                costNew = 0.5 * MiscMath.dot(fNew, fNew);
                actualReduction = cost - costNew;
                double[] updatedTrustRegion = this.updateTrustRegionRadius(Delta, actualReduction, predictedReduction, stepHnorm, stepHnorm > 0.95 * Delta);
                double DeltaNew = updatedTrustRegion[0];
                double ratio = updatedTrustRegion[1];
                boolean ftolSatisfied = actualReduction < 1.0E-8 * cost && ratio > 0.25;
                boolean xtolSatisfied = stepHnorm < 1.0E-8 * (1.0E-8 + MiscMath.norm(hA));
                boolean bl = terminationStatus = ftolSatisfied || xtolSatisfied;
                if (terminationStatus) break;
                alpha *= Delta / DeltaNew;
                Delta = DeltaNew;
            }
            if (actualReduction > 0.0) {
                hA = xNew;
                f0 = fNew;
                cost = costNew;
                J = this.approximateDerivative(repMatrix, scoresOld, index);
                g = J.transpose().sumProduct(f0);
            } else {
                stepHnorm = 0.0;
                actualReduction = 0.0;
            }
            ++iteration;
        }
        return hA;
    }

    private double[] leastSquaresOptimisation(MatrixI repMatrix, MatrixI scoresOld, int index) {
        MatrixI J = this.approximateDerivative(repMatrix, scoresOld, index);
        double[] result = this.trf(repMatrix, scoresOld, index, J);
        return result;
    }

    private class TrustRegion {
        private double[] step;
        private double alpha;
        private int iteration;

        public TrustRegion(double[] step, double alpha, int iteration) {
            this.step = step;
            this.alpha = alpha;
            this.iteration = iteration;
        }

        public double[] getStep() {
            return this.step;
        }

        public double getAlpha() {
            return this.alpha;
        }

        public int getIteration() {
            return this.iteration;
        }
    }
}

