/*
 * Decompiled with CFR 0.152.
 */
package org.ohdsi.likelihood;

import com.github.lbfgs4j.LbfgsMinimizer;
import com.github.lbfgs4j.liblbfgs.Function;
import com.github.lbfgs4j.liblbfgs.Lbfgs;
import com.github.lbfgs4j.liblbfgs.LbfgsConstant;
import dr.inference.model.AbstractModelLikelihood;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import java.util.Arrays;
import org.apache.commons.math.util.FastMath;
import org.ohdsi.data.SccsData;

public class SccsPartialLikelihood
extends AbstractModelLikelihood {
    private static final long serialVersionUID = 5911070778889767445L;
    private final Parameter beta;
    private final SccsData data;
    private final double n;
    private final double[] exps;
    private final double[] xps;
    private final int[] idxs;
    private boolean likelihoodKnown;
    private double minLogLikelihood;
    private boolean storedLikelihoodKnown;
    private double storedMinLogLikelihood;
    private double[] minGradient;

    public SccsPartialLikelihood(Parameter beta, SccsData data) {
        super("SccsPartialLikelihood");
        this.beta = beta;
        this.data = data;
        this.n = data.y.length;
        this.addVariable((Variable)beta);
        int stratumSize = 1;
        int maxStratumSize = 0;
        int i = 1;
        while ((double)i < this.n) {
            if (data.stratumId[i - 1] != data.stratumId[i]) {
                if (stratumSize > maxStratumSize) {
                    maxStratumSize = stratumSize;
                }
                stratumSize = 0;
            }
            if (data.y[i] != 0) {
                ++stratumSize;
            }
            ++i;
        }
        if (stratumSize > maxStratumSize) {
            maxStratumSize = stratumSize;
        }
        this.exps = new double[maxStratumSize];
        this.xps = new double[maxStratumSize];
        this.idxs = new int[maxStratumSize];
        this.minGradient = new double[data.x[0].length];
        this.likelihoodKnown = false;
    }

    private void computeLogLikelihoodAndGradient(double pA, double[] pX) {
        int cursor = 0;
        double sumExps = 0.0;
        double[] sumExpXs = new double[pX.length];
        Arrays.fill(sumExpXs, 0.0);
        this.minLogLikelihood = 0.0;
        Arrays.fill(this.minGradient, 0.0);
        int i = 0;
        while ((double)i < this.n) {
            double xp = this.data.a[i] * pA;
            int j = 0;
            while (j < pX.length) {
                xp += this.data.x[i][j] * pX[j];
                ++j;
            }
            double exp = this.data.time[i] * FastMath.exp((double)xp);
            sumExps += exp;
            int j2 = 0;
            while (j2 < pX.length) {
                int n = j2;
                sumExpXs[n] = sumExpXs[n] + this.data.x[i][j2] * exp;
                ++j2;
            }
            if (this.data.y[i] != 0) {
                this.xps[cursor] = xp;
                this.exps[cursor] = exp;
                this.idxs[cursor] = i;
                ++cursor;
            }
            if ((double)i == this.n - 1.0 || this.data.stratumId[i] != this.data.stratumId[i + 1]) {
                double logDenominator = FastMath.log((double)sumExps);
                int j3 = 0;
                while (j3 < cursor) {
                    int idx = this.idxs[j3];
                    this.minLogLikelihood -= (double)this.data.y[idx] * (FastMath.log((double)this.exps[j3]) - logDenominator);
                    int k = 0;
                    while (k < pX.length) {
                        double part1 = (double)this.data.y[idx] * FastMath.exp((double)(-this.xps[j3])) * sumExps / this.data.time[idx];
                        double part2 = this.data.x[idx][k] * this.exps[j3] / sumExps;
                        double part3 = this.exps[j3] * sumExpXs[k] / FastMath.pow((double)sumExps, (double)2.0);
                        int n = k++;
                        this.minGradient[n] = this.minGradient[n] - part1 * (part2 - part3);
                    }
                    ++j3;
                }
                cursor = 0;
                sumExps = 0.0;
                Arrays.fill(sumExpXs, 0.0);
            }
            ++i;
        }
    }

    private void refitModelAtNewBeta() {
        OptimizableFuntion function = new OptimizableFuntion();
        LbfgsConstant.LBFGS_Param params = Lbfgs.defaultParams();
        params.epsilon = 1.0E-6;
        params.delta = 1.0E-6;
        LbfgsMinimizer minimizer = new LbfgsMinimizer(params, false);
        minimizer.minimize((Function)function);
    }

    protected void handleVariableChangedEvent(Variable variable, int i, Variable.ChangeType changeType) {
        if (variable == this.beta) {
            this.likelihoodKnown = false;
        }
    }

    protected void storeState() {
        this.storedLikelihoodKnown = this.likelihoodKnown;
        this.storedMinLogLikelihood = this.minLogLikelihood;
    }

    protected void restoreState() {
        this.likelihoodKnown = this.storedLikelihoodKnown;
        this.minLogLikelihood = this.storedMinLogLikelihood;
    }

    protected void acceptState() {
    }

    public Model getModel() {
        return this;
    }

    public double getLogLikelihood() {
        if (!this.likelihoodKnown) {
            this.refitModelAtNewBeta();
            this.likelihoodKnown = true;
        }
        return -this.minLogLikelihood;
    }

    public void makeDirty() {
        this.likelihoodKnown = false;
    }

    protected void handleModelChangedEvent(Model arg0, Object arg1, int arg2) {
    }

    public static void main(String[] args) {
        int[] nArray = new int[38];
        nArray[0] = 1;
        nArray[5] = 1;
        nArray[9] = 1;
        nArray[13] = 1;
        nArray[18] = 1;
        nArray[22] = 1;
        nArray[25] = 1;
        nArray[29] = 1;
        nArray[35] = 1;
        nArray[36] = 1;
        int[] y = nArray;
        double[] time = new double[]{107.0, 21.0, 54.0, 183.0, 41.0, 21.0, 120.0, 183.0, 78.0, 21.0, 83.0, 183.0, 82.0, 21.0, 79.0, 183.0, 81.0, 21.0, 80.0, 183.0, 44.0, 21.0, 117.0, 183.0, 119.0, 21.0, 42.0, 183.0, 145.0, 21.0, 16.0, 183.0, 77.0, 21.0, 84.0, 183.0, 182.0, 183.0};
        double[] a = new double[]{0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0};
        double[][] x = new double[][]{{0.0}, {0.0}, {0.0}, {1.0}, {0.0}, {0.0}, {0.0}, {1.0}, {0.0}, {0.0}, {0.0}, {1.0}, {0.0}, {0.0}, {0.0}, {1.0}, {0.0}, {0.0}, {0.0}, {1.0}, {0.0}, {0.0}, {0.0}, {1.0}, {0.0}, {0.0}, {0.0}, {1.0}, {0.0}, {0.0}, {0.0}, {1.0}, {0.0}, {0.0}, {0.0}, {1.0}, {0.0}, {1.0}};
        int[] stratumId = new int[]{1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8, 9, 9, 9, 9, 10, 10};
        SccsData data = new SccsData(y, a, x, stratumId, time);
        Parameter.Default parameter = new Parameter.Default(-9999.9);
        SccsPartialLikelihood sccs = new SccsPartialLikelihood((Parameter)parameter, data);
        System.err.println(sccs.getLogLikelihood());
        parameter.setParameterValue(0, 2.487975);
        sccs.makeDirty();
        System.err.println(sccs.getLogLikelihood());
    }

    class OptimizableFuntion
    implements Function {
        double[] lastPoint;

        public OptimizableFuntion() {
            this.lastPoint = new double[SccsPartialLikelihood.this.minGradient.length];
            Arrays.fill(this.lastPoint, 9999.0);
        }

        private void updateIfNeeded(double[] point) {
            boolean needUpdate = false;
            int i = 0;
            while (i < point.length) {
                if (this.lastPoint[i] != point[i]) {
                    needUpdate = true;
                    break;
                }
                ++i;
            }
            if (needUpdate) {
                SccsPartialLikelihood.this.computeLogLikelihoodAndGradient(SccsPartialLikelihood.this.beta.getParameterValue(0), point);
            }
        }

        public int getDimension() {
            return SccsPartialLikelihood.this.minGradient.length;
        }

        public double[] gradientAt(double[] arg0) {
            this.updateIfNeeded(arg0);
            return SccsPartialLikelihood.this.minGradient;
        }

        public double valueAt(double[] arg0) {
            this.updateIfNeeded(arg0);
            return SccsPartialLikelihood.this.minLogLikelihood;
        }
    }
}

