/*
 * Decompiled with CFR 0.152.
 */
package at.tugraz.genome.lda.utils;

import at.tugraz.genome.lda.exception.LMException;
import at.tugraz.genome.util.FloatMatrix;
import java.io.Serializable;
import java.util.Vector;

/**
 * Base class for a Levenberg-Marquardt least-squares fit of a parametric model against observations.
 *
 * <p>Subclasses provide the model ({@link #calculateEquationResults}), its Jacobian
 * ({@link #calculateJacobianMatrix}), the input values and the observations. {@link #fit(float[][])}
 * then refines the parameter vector. Each iteration solves
 * {@code (J^T J + lambda * diag(J^T J)) * delta = J^T r} with {@code r = observations - model}
 * and adapts the damping factor {@code lambda} by the multiplier from
 * {@link #getLambdaMultiplierStartValue()}.
 */
public abstract class LevenbergMarquardtOptimizer {
    /** Parameter vector produced by the last {@link #fit(float[][])} call; null before any fit. */
    protected FloatMatrix resultParams_;
    /** Chi-squared value reached by the last fit (positive infinity if no iteration ran). */
    protected double resultChiSqr_;
    /** Damping factor lambda in effect at the end of the last fit. */
    protected float resultLambda_;
    /** Optional value returned by {@link #getMeanDeviation()} instead of the computed one; cleared by every fit. */
    protected Float maxDev_;

    /** Immutable pair of a corrected parameter vector and the damping factor that produced it. */
    private static final class Step {
        final FloatMatrix params;
        final float lambda;

        Step(FloatMatrix params, float lambda) {
            this.params = params;
            this.lambda = lambda;
        }
    }

    /**
     * Runs the Levenberg-Marquardt iteration from the given start parameters.
     *
     * <p>The outcome is stored in {@link #resultParams_}, {@link #resultChiSqr_} and
     * {@link #resultLambda_}. Any value previously set via {@link #setMaxDeviation(float)} is
     * cleared. Iteration stops when two consecutive finite chi-squared values are exactly equal,
     * or after {@link #getMaximumOfIterations()} iterations.
     *
     * @param initParameters initial parameter values; presumably one row per parameter in a
     *        single column (TODO confirm against subclasses)
     * @throws LMException if no damping factor can be found that keeps the sum of squared
     *         residues from increasing
     */
    protected void fit(float[][] initParameters) throws LMException {
        this.maxDev_ = null;
        FloatMatrix paramsVector = new FloatMatrix(initParameters);
        FloatMatrix resultVector = new FloatMatrix(this.getObservations());
        double previousChiSqr = Double.POSITIVE_INFINITY;
        float lambda = this.getLambdaStartValue();
        float vLambda = this.getLambdaMultiplierStartValue();
        int maxIterations = this.getMaximumOfIterations();
        // '<' rather than '!=' so a non-positive limit performs no iterations instead of wrapping around
        for (int i = 0; i < maxIterations; ++i) {
            Step step = this.levenbergMarquadtIteration(paramsVector, this.getValues(), resultVector, lambda, vLambda);
            paramsVector = step.params;
            lambda = step.lambda;
            double chiSquared = this.calculateChiSquared(paramsVector, this.getValues(), resultVector);
            boolean converged = previousChiSqr == chiSquared && !Double.isInfinite(chiSquared);
            previousChiSqr = chiSquared;
            if (converged) {
                break;
            }
        }
        this.resultParams_ = paramsVector;
        this.resultChiSqr_ = previousChiSqr;
        this.resultLambda_ = lambda;
    }

    /** Number of observations (rows of {@code resultVector}) minus number of parameters (rows of {@code params}). */
    private int getDegreesOfFreedom(FloatMatrix params, FloatMatrix resultVector) {
        return resultVector.m - params.m;
    }

    /**
     * Performs one damped Gauss-Newton step.
     *
     * <p>Builds the Jacobian at {@code params}, picks a damping factor and returns the corrected
     * parameters together with that factor.
     */
    private Step levenbergMarquadtIteration(FloatMatrix params, float[][] values, FloatMatrix resultVector, float lambdaBefore, float vLambda) throws LMException {
        FloatMatrix jacobian = this.calculateJacobianMatrix(values, params.A);
        FloatMatrix jacobianTransposed = jacobian.transpose();
        FloatMatrix jacobianProduct = jacobianTransposed.times(jacobian);
        FloatMatrix jacobianDiag = this.getDiagonalMatrix(jacobianProduct);
        // The residues at the current params do not depend on lambda, so J^T r is computed once
        // and shared by every trial solve below.
        FloatMatrix currentResidues = this.calculateResidues(resultVector, values, params);
        FloatMatrix rightSide = jacobianTransposed.times(currentResidues);
        float sumOfSquaresCurrent = this.calculateSumOfSquares(currentResidues);
        return this.detectBestStep(params, values, resultVector, lambdaBefore, vLambda, sumOfSquaresCurrent, rightSide, jacobianProduct, jacobianDiag);
    }

    /**
     * Chooses the damping factor for this iteration and returns the step it produces.
     *
     * <p>If either {@code lambdaBefore} or {@code lambdaBefore / vLambda} lowers the sum of
     * squared residues, the smaller factor is preferred whenever it improves. Otherwise lambda is
     * multiplied by {@code vLambda} until the sum of squares no longer increases and is not NaN.
     *
     * @throws LMException if lambda overflows or cannot be increased any further
     */
    private Step detectBestStep(FloatMatrix params, float[][] values, FloatMatrix resultVector, float lambdaBefore, float vLambda, float sumOfSquaresCurrent, FloatMatrix rightSide, FloatMatrix jacobianProduct, FloatMatrix jacobianDiag) throws LMException {
        float lambdaSmaller = lambdaBefore / vLambda;
        FloatMatrix paramsLambda = this.solveOneLMCycle(params, lambdaBefore, rightSide, jacobianProduct, jacobianDiag);
        float sumOfSquaresLambda = this.calculateSumOfSquares(this.calculateResidues(resultVector, values, paramsLambda));
        FloatMatrix paramsLambdaSmaller = this.solveOneLMCycle(params, lambdaSmaller, rightSide, jacobianProduct, jacobianDiag);
        float sumOfSquaresLambdaSmaller = this.calculateSumOfSquares(this.calculateResidues(resultVector, values, paramsLambdaSmaller));

        boolean lambdaImproves = !Float.isNaN(sumOfSquaresLambda) && sumOfSquaresLambda < sumOfSquaresCurrent;
        boolean smallerImproves = !Float.isNaN(sumOfSquaresLambdaSmaller) && sumOfSquaresLambdaSmaller < sumOfSquaresCurrent;
        if (lambdaImproves || smallerImproves) {
            // Less damping (closer to a pure Gauss-Newton step) wins whenever it also improves.
            return smallerImproves ? new Step(paramsLambdaSmaller, lambdaSmaller) : new Step(paramsLambda, lambdaBefore);
        }

        float lambda = lambdaBefore;
        while (Float.isNaN(sumOfSquaresLambda) || sumOfSquaresLambda > sumOfSquaresCurrent) {
            lambda = increaseLambda(lambda, vLambda);
            paramsLambda = this.solveOneLMCycle(params, lambda, rightSide, jacobianProduct, jacobianDiag);
            sumOfSquaresLambda = this.calculateSumOfSquares(this.calculateResidues(resultVector, values, paramsLambda));
        }
        return new Step(paramsLambda, lambda);
    }

    /**
     * Returns {@code lambda * vLambda}, guaranteeing strict growth so that the search in
     * {@link #detectBestStep} terminates.
     *
     * @throws LMException if the product overflows to infinity, or if lambda cannot grow
     *         (e.g. a multiplier of at most 1)
     */
    private static float increaseLambda(float lambda, float vLambda) throws LMException {
        float increased = lambda * vLambda;
        if (Float.isInfinite(increased)) {
            throw new LMException("The curve cannot be fitted - singular matrix");
        }
        if (!(increased > lambda) && lambda < Float.MIN_NORMAL) {
            // Repeated division by vLambda can underflow lambda to 0 (or a subnormal value that
            // multiplication cannot lift); restart from the smallest normal float instead of looping forever.
            increased = Float.MIN_NORMAL;
        }
        if (!(increased > lambda)) {
            throw new LMException("The curve cannot be fitted - lambda multiplier must be greater than 1");
        }
        return increased;
    }

    /**
     * Solves {@code (J^T J + lambda * diag(J^T J)) * delta = rightSide} and returns {@code params + delta}.
     *
     * @param rightSide {@code J^T r} for the residues at {@code params}
     */
    private FloatMatrix solveOneLMCycle(FloatMatrix params, float lambda, FloatMatrix rightSide, FloatMatrix jacobianProduct, FloatMatrix jacobianDiag) {
        // copy() is a full copy of the matrix; the former A.clone() shared its rows with jacobianDiag.
        FloatMatrix lambdaMatrix = jacobianDiag.copy().times(lambda);
        FloatMatrix leftSide = jacobianProduct.plus(lambdaMatrix);
        FloatMatrix delta = leftSide.inverse().times(rightSide);
        return params.plus(delta);
    }

    /** Returns the column of residues {@code observation - model} for every row of {@code values}. */
    private FloatMatrix calculateResidues(FloatMatrix resultVector, float[][] values, FloatMatrix paramsVector) {
        return residuesOf(resultVector, this.calculateEquationResults(values, paramsVector), values.length);
    }

    /** Builds an {@code n x 1} matrix holding {@code resultVector[i] - results[i]} for the first {@code n} rows. */
    private static FloatMatrix residuesOf(FloatMatrix resultVector, FloatMatrix results, int n) {
        float[][] residues = new float[n][1];
        for (int i = 0; i < n; ++i) {
            residues[i][0] = resultVector.A[i][0] - results.A[i][0];
        }
        return new FloatMatrix(residues);
    }

    /** Returns a copy of {@code in} with every off-diagonal element set to zero. */
    private FloatMatrix getDiagonalMatrix(FloatMatrix in) {
        FloatMatrix mx = in.copy();
        for (int i = 0; i < mx.A.length; ++i) {
            for (int j = 0; j < mx.A[i].length; ++j) {
                if (i != j) {
                    mx.A[i][j] = 0.0f;
                }
            }
        }
        return mx;
    }

    /** Sum of squares of the first column of {@code res}, accumulated in float precision. */
    private float calculateSumOfSquares(FloatMatrix res) {
        float sumOfSquares = 0.0f;
        for (int i = 0; i < res.m; ++i) {
            sumOfSquares = (float)((double)sumOfSquares + Math.pow(res.A[i][0], 2.0));
        }
        return sumOfSquares;
    }

    /**
     * Returns {@code sum(residue_i^2 / |model_i|)} divided by the degrees of freedom.
     *
     * <p>A model value of zero, or zero degrees of freedom, makes the result infinite or NaN.
     */
    private float calculateChiSquared(FloatMatrix params, float[][] values, FloatMatrix resultVector) {
        // Evaluate the model once and derive the residues from it rather than evaluating twice.
        FloatMatrix results = this.calculateEquationResults(values, params);
        FloatMatrix residues = residuesOf(resultVector, results, values.length);
        float chiSquared = 0.0f;
        for (int i = 0; i < values.length; ++i) {
            chiSquared = (float)((double)chiSquared + Math.pow(residues.A[i][0], 2.0) / (double)Math.abs(results.A[i][0]));
        }
        int df = this.getDegreesOfFreedom(params, resultVector);
        return chiSquared / (float)df;
    }

    /**
     * Returns the root-mean-square residue of the last fit, {@code sqrt(sum((observation - model)^2) / n)}.
     *
     * <p>If a value was supplied via {@link #setMaxDeviation(float)} after the last fit, that value
     * is returned instead.
     *
     * @throws LMException if no fit result is available yet
     */
    public float getMeanDeviation() throws LMException {
        if (this.maxDev_ != null) {
            return this.maxDev_.floatValue();
        }
        if (this.resultParams_ == null) {
            throw new LMException("The curve has not been fitted yet - call fit() first");
        }
        FloatMatrix res = this.calculateResidues(new FloatMatrix(this.getObservations()), this.getValues(), this.resultParams_);
        float sumOfSquares = this.calculateSumOfSquares(res);
        return (float)Math.sqrt(sumOfSquares / (float)res.m);
    }

    /** Overrides the value returned by {@link #getMeanDeviation()} until the next fit. */
    public void setMaxDeviation(float maxDev) {
        this.maxDev_ = Float.valueOf(maxDev);
    }

    /** Initial damping factor; subclasses may override. */
    protected float getLambdaStartValue() {
        return 10.0f;
    }

    /** Factor by which lambda is decreased or increased; must be greater than 1 for the search to terminate. */
    protected float getLambdaMultiplierStartValue() {
        return 1.5f;
    }

    /** Upper bound on the number of Levenberg-Marquardt iterations in {@link #fit(float[][])}. */
    protected int getMaximumOfIterations() {
        return 10000;
    }

    /**
     * Evaluates the model for every row of {@code values}.
     *
     * <p>Must return a matrix whose element {@code [i][0]} is the model value for {@code values[i]}.
     */
    protected abstract FloatMatrix calculateEquationResults(float[][] values, FloatMatrix params);

    /**
     * Computes the Jacobian of the model at the given parameters.
     *
     * <p>Presumably one row per observation and one column per parameter, holding
     * {@code d model_i / d param_j} (TODO confirm).
     */
    protected abstract FloatMatrix calculateJacobianMatrix(float[][] values, float[][] params);

    /** Input values, one row per observation. */
    protected abstract float[][] getValues();

    /** Observed values; element {@code [i][0]} is the observation matching {@code getValues()[i]}. */
    protected abstract float[][] getObservations();

    /** Performs the fit; presumably delegates to {@link #fit(float[][])} with subclass start values. */
    public abstract void fit() throws LMException;

    /** Presumably evaluates the fitted model at the given input (TODO confirm against subclasses). */
    public abstract float calculateFitValue(float[] values) throws LMException;

    /** Returns the parameter vector of the last fit, or null if no fit has run. */
    public FloatMatrix getResultParams() {
        return this.resultParams_;
    }

    /** Returns the chi-squared value reached by the last fit. */
    public double getResultChiSqr() {
        return this.resultChiSqr_;
    }

    /** Returns the damping factor in effect at the end of the last fit. */
    public float getResultLambda() {
        return this.resultLambda_;
    }
}

