/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.branchratemodel;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTraitProvider;
import dr.evolution.tree.TreeUtils;
import dr.evomodel.bigfasttree.BigFastTreeIntervals;
import dr.evomodel.branchratemodel.AbstractBranchRateModel;
import dr.evomodel.branchratemodel.ArbitraryBranchRates;
import dr.evomodel.branchratemodel.DifferentiableBranchRates;
import dr.evomodel.tree.TreeModel;
import dr.inference.model.AbstractModel;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.util.Author;
import dr.util.Citable;
import dr.util.Citation;
import java.util.Collections;
import java.util.List;

/**
 * Branch rate model in which the clock rate is a piecewise function of time (node height) over a set
 * of epochs.
 *
 * <p>Each branch is assigned the rate integrated over the portion of every epoch it spans, divided by
 * the branch length, i.e. the time-averaged rate along the branch. Epoch boundaries come from an
 * {@link EpochTimeProvider}; the per-epoch values come from a rate {@link Parameter}, interpreted by a
 * {@link FunctionalForm}.
 */
public class TimeVaryingBranchRateModel extends AbstractBranchRateModel
        implements DifferentiableBranchRates, Citable {

    private final Tree tree;
    /** Per-epoch rate parameters; index {@code i} is the value for epoch {@code i}. */
    private final Parameter rates;
    private final EpochTimeProvider epochTimeProvider;
    private boolean nodeRatesKnown;
    private boolean storedNodeRatesKnown;
    /** Cached branch rates, indexed by {@link #getParameterIndexFromNode(NodeRef)}. */
    private double[] nodeRates;
    private double[] storedNodeRates;
    private final FunctionalForm functionalForm;

    /**
     * Creates the model with epoch boundaries taken from a parameter.
     *
     * @param type       functional form used to turn rate parameters into rates
     * @param tree       tree whose branches receive rates
     * @param parameter  per-epoch rate parameters
     * @param parameter2 epoch boundary heights; an implicit boundary at height 0 is prepended
     */
    public TimeVaryingBranchRateModel(FunctionalForm.Type type, Tree tree, Parameter parameter, Parameter parameter2) {
        this(type, tree, parameter, new EpochTimeProvider.ParameterWrapper(parameter2));
    }

    /**
     * Creates the model with an arbitrary source of epoch boundaries.
     *
     * @param type              functional form used to turn rate parameters into rates
     * @param tree              tree whose branches receive rates
     * @param parameter         per-epoch rate parameters
     * @param epochTimeProvider source of epoch boundary heights (assumed ascending; the epoch
     *                          searches below walk downward from the last entry)
     */
    public TimeVaryingBranchRateModel(FunctionalForm.Type type, Tree tree, Parameter parameter, EpochTimeProvider epochTimeProvider) {
        super("timeVaryingRates");
        this.tree = tree;
        this.rates = parameter;
        this.epochTimeProvider = epochTimeProvider;
        if (tree instanceof TreeModel) {
            addModel((TreeModel) tree);
        }
        addVariable(parameter);
        addModel(epochTimeProvider);
        this.nodeRates = new double[tree.getNodeCount()];
        this.storedNodeRates = new double[tree.getNodeCount()];
        this.functionalForm = type.factory(parameter);
        this.nodeRatesKnown = false;
    }

    /**
     * Returns the time-averaged rate along the branch above {@code nodeRef}, recomputing all branch
     * rates lazily when the cache has been invalidated.
     */
    @Override
    public double getBranchRate(Tree tree, NodeRef nodeRef) {
        assert tree == this.tree;
        if (!nodeRatesKnown) {
            calculateNodeGeneric(new Traversal.Rates(nodeRates, functionalForm));
            nodeRatesKnown = true;
        }
        return nodeRates[getParameterIndexFromNode(nodeRef)];
    }

    /**
     * Maps a gradient with respect to branch rates onto a gradient with respect to the per-epoch
     * rate parameters.
     *
     * @param gradient gradient of the log density with respect to each branch rate, indexed by
     *                 {@link #getParameterIndexFromNode(NodeRef)}
     * @param value    unused here
     * @param from     must be 0
     * @param to       must be the last rate-parameter index
     * @return a new array of length {@code rates.getDimension()} holding the per-epoch gradient
     */
    @Override
    public double[] updateGradientLogDensity(double[] gradient, double[] value, int from, int to) {
        assert from == 0;
        assert to == rates.getDimension() - 1;
        final double[] epochGradient = new double[rates.getDimension()];
        calculateNodeGeneric(new Traversal.Gradient(epochGradient, gradient, functionalForm));
        return epochGradient;
    }

    @Override
    public double getBranchRateDifferential(Tree tree, NodeRef nodeRef) {
        return 1.0;
    }

    @Override
    public double getBranchRateSecondDifferential(Tree tree, NodeRef nodeRef) {
        throw new RuntimeException("Not yet implemented");
    }

    /**
     * Invalidates cached rates when a registered sub-model changes.
     *
     * <p>The constructor registers both the tree (when it is a {@link TreeModel}) and the epoch-time
     * provider as sub-models, so events from either are expected. Any other source is an error.
     */
    @Override
    public void handleModelChangedEvent(Model model, Object object, int n) {
        if (model != tree && model != epochTimeProvider) {
            throw new IllegalArgumentException("How did we get here?");
        }
        nodeRatesKnown = false;
        fireModelChanged();
    }

    /** Any change to the rate parameter invalidates the cached branch rates. */
    @Override
    protected final void handleVariableChangedEvent(Variable variable, int n, Variable.ChangeType changeType) {
        nodeRatesKnown = false;
        fireModelChanged();
    }

    /** Snapshots the cached branch rates and their validity flag. */
    @Override
    protected void storeState() {
        if (storedNodeRates == null) {
            storedNodeRates = new double[nodeRates.length];
        }
        System.arraycopy(nodeRates, 0, storedNodeRates, 0, nodeRates.length);
        storedNodeRatesKnown = nodeRatesKnown;
    }

    /** Restores the snapshot by swapping buffers rather than copying. */
    @Override
    protected void restoreState() {
        final double[] swap = nodeRates;
        nodeRates = storedNodeRates;
        storedNodeRates = swap;
        nodeRatesKnown = storedNodeRatesKnown;
    }

    @Override
    protected void acceptState() {
    }

    @Override
    public Parameter getRateParameter() {
        return rates;
    }

    /**
     * Maps a node number to a dense branch index by skipping the root, which has no branch above it.
     */
    @Override
    public int getParameterIndexFromNode(NodeRef nodeRef) {
        final int number = nodeRef.getNumber();
        return number > tree.getRoot().getNumber() ? number - 1 : number;
    }

    @Override
    public ArbitraryBranchRates.BranchRateTransform getTransform() {
        throw new RuntimeException("Not yet implemented");
    }

    @Override
    public double[] updateDiagonalHessianLogDensity(double[] dArray, double[] dArray2, double[] dArray3, int n, int n2) {
        throw new RuntimeException("Not yet implemented");
    }

    @Override
    public Citation.Category getCategory() {
        return Citation.Category.MOLECULAR_CLOCK;
    }

    @Override
    public String getDescription() {
        return "Time-varying branch rate model";
    }

    @Override
    public List<Citation> getCitations() {
        return Collections.singletonList(new Citation(new Author[]{new Author("P", "Datta"), new Author("P", "Lemey"), new Author("MA", "Suchard")}, Citation.Status.IN_PREPARATION));
    }

    /** Newick representation of the tree annotated with this model's traits. */
    @Override
    public String toString() {
        return TreeUtils.newick(tree, new TreeTraitProvider[]{this});
    }

    /**
     * Runs {@code traversal} over every branch of the tree, starting from the epoch that contains the
     * root height.
     */
    private void calculateNodeGeneric(Traversal traversal) {
        final double[] epochTimes = epochTimeProvider.getEpochTimes();
        final NodeRef root = tree.getRoot();
        final double rootHeight = tree.getNodeHeight(root);
        // Locate the epoch containing the root; clamp at 0 instead of running off the array.
        int epoch = epochTimes.length - 1;
        while (epoch > 0 && epochTimes[epoch] >= rootHeight) {
            --epoch;
        }
        traverseTreeByBranchGeneric(epochTimes, rootHeight, tree.getChild(root, 0), epoch, traversal);
        traverseTreeByBranchGeneric(epochTimes, rootHeight, tree.getChild(root, 1), epoch, traversal);
    }

    /**
     * Splits the branch above {@code node} into per-epoch segments, feeds each segment to
     * {@code traversal}, then recurses into the children.
     *
     * @param epochTimes   epoch boundary heights
     * @param parentHeight height of the parent node (top of this branch)
     * @param node         node at the bottom of this branch
     * @param epoch        epoch containing {@code parentHeight}
     * @param traversal    per-segment accumulator
     */
    private void traverseTreeByBranchGeneric(double[] epochTimes, double parentHeight, NodeRef node,
                                             int epoch, Traversal traversal) {
        final double nodeHeight = tree.getNodeHeight(node);
        final double branchLength = parentHeight - nodeHeight;
        final int nodeIndex = getParameterIndexFromNode(node);
        traversal.reset();
        if (parentHeight > nodeHeight) {
            double segmentStart = parentHeight;
            // Walk down through every epoch boundary that lies strictly inside the branch.
            while (epoch > 0 && epochTimes[epoch] > nodeHeight) {
                traversal.increment(epoch, nodeIndex, segmentStart, epochTimes[epoch], branchLength);
                segmentStart = epochTimes[epoch];
                --epoch;
            }
            traversal.increment(epoch, nodeIndex, segmentStart, nodeHeight, branchLength);
        }
        traversal.store(epoch, nodeIndex, branchLength);
        if (!tree.isExternal(node)) {
            traverseTreeByBranchGeneric(epochTimes, nodeHeight, tree.getChild(node, 0), epoch, traversal);
            traverseTreeByBranchGeneric(epochTimes, nodeHeight, tree.getChild(node, 1), epoch, traversal);
        }
    }

    /** Source of epoch boundary heights for {@link TimeVaryingBranchRateModel}. */
    public interface EpochTimeProvider extends Model {

        /** Returns the epoch boundary heights (the returned array is the provider's internal cache). */
        double[] getEpochTimes();

        /** Uses the event times of a {@link BigFastTreeIntervals} as epoch boundaries. */
        public static class IntervalWrapper extends AbstractEpochTimeProvider {
            private final BigFastTreeIntervals intervals;

            public IntervalWrapper(BigFastTreeIntervals bigFastTreeIntervals) {
                super("IntervalWrapper");
                this.intervals = bigFastTreeIntervals;
                addModel(bigFastTreeIntervals);
            }

            /** Copies each interval's time, reallocating if the interval count has changed. */
            @Override
            void computeTimes() {
                final int count = intervals.getIntervalCount();
                if (times == null || times.length != count) {
                    times = new double[count];
                }
                for (int i = 0; i < count; ++i) {
                    times[i] = intervals.getIntervalTime(i);
                }
            }

            @Override
            protected void handleModelChangedEvent(Model model, Object object, int n) {
                assert model == intervals;
                timesKnown = false;
                fireModelChanged();
            }

            @Override
            protected void handleVariableChangedEvent(Variable variable, int n, Variable.ChangeType changeType) {
                throw new IllegalArgumentException("Should not be called");
            }

            @Override
            protected void storeState() {
            }

            /**
             * Invalidates the cache, as {@link ParameterWrapper#restoreState()} does: the cached times
             * may have been computed from the state being rolled back.
             */
            @Override
            protected void restoreState() {
                timesKnown = false;
            }
        }

        /** Uses a parameter's values as epoch boundaries, preceded by an implicit boundary at 0. */
        public static class ParameterWrapper extends AbstractEpochTimeProvider {
            private final Parameter epochTimes;

            public ParameterWrapper(Parameter parameter) {
                super("ParameterWrapper");
                this.epochTimes = parameter;
                addVariable(parameter);
            }

            /** Builds {@code [0, p_0, p_1, ...]}, reallocating if the parameter dimension changed. */
            @Override
            void computeTimes() {
                final int dim = epochTimes.getDimension();
                if (times == null || times.length != dim + 1) {
                    times = new double[dim + 1];
                }
                times[0] = 0.0;
                System.arraycopy(epochTimes.getParameterValues(), 0, times, 1, dim);
            }

            @Override
            protected void handleModelChangedEvent(Model model, Object object, int n) {
                throw new IllegalArgumentException("Should not be called");
            }

            @Override
            protected void handleVariableChangedEvent(Variable variable, int n, Variable.ChangeType changeType) {
                assert variable == epochTimes;
                timesKnown = false;
                fireModelChanged();
            }

            @Override
            protected void storeState() {
            }

            @Override
            protected void restoreState() {
                timesKnown = false;
            }
        }

        /** Lazily computes and caches epoch times; subclasses fill {@link #times}. */
        public static abstract class AbstractEpochTimeProvider extends AbstractModel implements EpochTimeProvider {
            double[] times;
            boolean timesKnown;

            public AbstractEpochTimeProvider(String string) {
                super(string);
            }

            @Override
            protected void acceptState() {
            }

            @Override
            public double[] getEpochTimes() {
                if (!timesKnown) {
                    computeTimes();
                    timesKnown = true;
                }
                return times;
            }

            /** Recomputes {@link #times} from the underlying source. */
            abstract void computeTimes();
        }
    }

    /** Maps per-epoch parameters to rates and accumulates a single branch's rate integral. */
    public interface FunctionalForm {

        /** Clears the accumulated rate integral before a new branch is processed. */
        void reset();

        /** Adds the rate integral of {@code epoch} over the height segment {@code [end, start]}. */
        void incrementRate(int epoch, double start, double end);

        /**
         * Returns the derivative of the branch's average rate with respect to the parameter of
         * {@code epoch}, for the segment {@code [end, start]} of a branch of length {@code branchLength}.
         */
        double gradientWeight(int epoch, double start, double end, double branchLength);

        /** Returns the rate in {@code epoch}. */
        double getRateParameter(int epoch);

        /** Returns the rate integral accumulated since the last {@link #reset()}. */
        double rateNumerator();

        /** Placeholder for integrable functional forms; not yet produced by {@link Type}. */
        public static abstract class Integrable extends Base {
            Integrable(Parameter parameter) {
                super(parameter);
            }
        }

        /** Placeholder for a piecewise-linear form; not yet produced by {@link Type}. */
        public static abstract class PiecewiseLinear implements FunctionalForm {
        }

        /** Piecewise-constant rates whose parameters are on the log scale: rate = exp(parameter). */
        public static class PiecewiseLogConstant extends PiecewiseConstant {
            PiecewiseLogConstant(Parameter parameter) {
                super(parameter);
            }

            @Override
            public double getRateParameter(int epoch) {
                return Math.exp(super.getRateParameter(epoch));
            }

            /** Chain rule: d exp(theta) / d theta = exp(theta). */
            @Override
            public double gradientWeight(int epoch, double start, double end, double branchLength) {
                return super.gradientWeight(epoch, start, end, branchLength) * getRateParameter(epoch);
            }
        }

        /** Piecewise-constant rates: rate in each epoch equals the raw parameter value. */
        public static class PiecewiseConstant extends Base {
            private double branchRateNumerator;

            PiecewiseConstant(Parameter parameter) {
                super(parameter);
            }

            @Override
            public void reset() {
                branchRateNumerator = 0.0;
            }

            @Override
            public double getRateParameter(int epoch) {
                return parameter.getParameterValue(epoch);
            }

            @Override
            public void incrementRate(int epoch, double start, double end) {
                branchRateNumerator += getRateParameter(epoch) * (start - end);
            }

            @Override
            public double rateNumerator() {
                return branchRateNumerator;
            }

            /** Fraction of the branch that lies in the segment. */
            @Override
            public double gradientWeight(int epoch, double start, double end, double branchLength) {
                return (start - end) / branchLength;
            }
        }

        /** Holds the rate parameter shared by concrete forms. */
        public static abstract class Base implements FunctionalForm {
            final Parameter parameter;

            Base(Parameter parameter) {
                this.parameter = parameter;
            }
        }

        /** Named functional forms, each with a factory for its implementation. */
        public static enum Type {
            PIECEWISE_CONSTANT("piecewiseConstant") {
                @Override
                FunctionalForm factory(Parameter parameter) {
                    return new PiecewiseConstant(parameter);
                }
            },
            PIECEWISE_LOG_CONSTANT("piecewiseLogConstant") {
                @Override
                FunctionalForm factory(Parameter parameter) {
                    return new PiecewiseLogConstant(parameter);
                }
            };

            private final String name;

            Type(String name) {
                this.name = name;
            }

            public String getName() {
                return name;
            }

            abstract FunctionalForm factory(Parameter parameter);

            /** Case-insensitive lookup by {@link #getName()}. */
            public static Type parse(String string) {
                for (Type type : values()) {
                    if (type.name.equalsIgnoreCase(string)) {
                        return type;
                    }
                }
                throw new IllegalArgumentException("Unknown FunctionalForm.Type");
            }
        }
    }

    /** Per-segment visitor used by the tree traversal for both rates and gradients. */
    interface Traversal {

        /** Called once at the start of each branch. */
        void reset();

        /** Visits one epoch segment {@code [end, start]} of the branch above node {@code nodeIndex}. */
        void increment(int epoch, int nodeIndex, double start, double end, double branchLength);

        /** Called once at the end of each branch; {@code epoch} contains the node's height. */
        void store(int epoch, int nodeIndex, double branchLength);

        /** Computes each branch's time-averaged rate into a node-indexed array. */
        public static class Rates extends AbstractTraversal {
            private final double[] nodeRates;

            Rates(double[] nodeRates, FunctionalForm functionalForm) {
                super(functionalForm);
                this.nodeRates = nodeRates;
            }

            @Override
            public void increment(int epoch, int nodeIndex, double start, double end, double branchLength) {
                functionalForm.incrementRate(epoch, start, end);
            }

            @Override
            public void store(int epoch, int nodeIndex, double branchLength) {
                if (branchLength == 0.0) {
                    // Zero-length branch: the average is 0/0. Use the limiting value instead, the
                    // rate of the epoch containing the node.
                    nodeRates[nodeIndex] = functionalForm.getRateParameter(epoch);
                } else {
                    nodeRates[nodeIndex] = functionalForm.rateNumerator() / branchLength;
                }
            }
        }

        /** Accumulates the per-epoch gradient from a per-branch gradient. */
        public static class Gradient extends AbstractTraversal {
            private final double[] gradientEpochs;
            private final double[] gradientNodes;

            Gradient(double[] gradientEpochs, double[] gradientNodes, FunctionalForm functionalForm) {
                super(functionalForm);
                this.gradientEpochs = gradientEpochs;
                this.gradientNodes = gradientNodes;
            }

            @Override
            public void increment(int epoch, int nodeIndex, double start, double end, double branchLength) {
                gradientEpochs[epoch] += gradientNodes[nodeIndex]
                        * functionalForm.gradientWeight(epoch, start, end, branchLength);
            }

            @Override
            public void store(int epoch, int nodeIndex, double branchLength) {
            }
        }

        /** Shares the functional form and forwards {@link #reset()} to it. */
        public static abstract class AbstractTraversal implements Traversal {
            final FunctionalForm functionalForm;

            AbstractTraversal(FunctionalForm functionalForm) {
                this.functionalForm = functionalForm;
            }

            @Override
            public void reset() {
                functionalForm.reset();
            }
        }
    }
}

