/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.branchratemodel;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evolution.tree.TreeTraitProvider;
import dr.evolution.tree.TreeUtils;
import dr.evolution.util.TaxonList;
import dr.evomodel.branchratemodel.AbstractBranchRateModel;
import dr.evomodel.branchratemodel.ArbitraryBranchRates;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.branchratemodel.DifferentiableBranchRates;
import dr.evomodel.tree.TreeModel;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.util.Author;
import dr.util.Citable;
import dr.util.Citation;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Branch-rate model that overlays local clocks on a global rate.
 *
 * <p>The global rate comes either from a single {@link Parameter} or from another
 * {@link BranchRateModel}. On top of it, three kinds of local clock can be registered:
 * <ul>
 *   <li>external-branch clocks, attached to individual tips ({@code ClockType.EXTERNAL});</li>
 *   <li>clade clocks, keyed by the exact set of tips below a node ({@code ClockType.CLADE});</li>
 *   <li>at most one trunk clock, covering a chosen set of tips and all of their ancestors
 *       ({@code ClockType.TRUNK}).</li>
 * </ul>
 * A local clock is either relative (its rate multiplies the global rate) or absolute (its rate
 * replaces the global rate). The node-to-clock assignment is cached and rebuilt lazily after any
 * sub-model change, a change of the trunk index parameter, or a state restore.
 */
public class LocalClockModel extends AbstractBranchRateModel
        implements Citable, DifferentiableBranchRates {

    private final TreeModel treeModel;
    /** Tip node number -> external-branch clock for that tip. */
    protected Map<Integer, LocalClock> localTipClocks = new HashMap<Integer, LocalClock>();
    /** Exact tip set (by node number) below a clade's top node -> clade clock. */
    protected Map<BitSet, LocalClock> localCladeClocks = new HashMap<BitSet, LocalClock>();
    /** The single trunk clock, or {@code null} if none has been added. */
    protected LocalClock trunkClock = null;
    /** Set whenever {@link #nodeClockMap} may be stale and must be rebuilt before use. */
    private boolean updateNodeClocks = true;
    /** Cached assignment of nodes to local clocks; nodes absent from the map use the global rate. */
    private final Map<NodeRef, LocalClock> nodeClockMap = new HashMap<NodeRef, LocalClock>();
    /** Global rate as a parameter; exactly one of this and {@link #globalBranchRates} is non-null. */
    private final Parameter globalRateParameter;
    /** Global rate as a branch-rate model; exactly one of this and {@link #globalRateParameter} is non-null. */
    private final BranchRateModel globalBranchRates;
    private final TreeTraitProvider.Helper helper = new TreeTraitProvider.Helper();
    public static Citation CITATION = new Citation(new Author[]{new Author("AD", "Yoder"), new Author("Z", "Yang")}, "Estimation of Primate Speciation Dates Using Local Molecular Clocks", 2000, "Mol Biol Evol", 17, 1081, 1090);

    /**
     * Creates a local clock model whose global rate is a single parameter.
     *
     * @param treeModel the tree whose branches receive rates
     * @param parameter the global rate; element 0 is read
     */
    public LocalClockModel(TreeModel treeModel, Parameter parameter) {
        super("localClockModel");
        this.treeModel = treeModel;
        addModel(treeModel);
        this.globalRateParameter = parameter;
        this.globalBranchRates = null;
        addVariable(parameter);
        helper.addTrait(this);
    }

    /**
     * Creates a local clock model whose global rate is supplied per branch by another model.
     *
     * @param treeModel       the tree whose branches receive rates
     * @param branchRateModel the model providing the global (background) rate for each branch
     */
    public LocalClockModel(TreeModel treeModel, BranchRateModel branchRateModel) {
        super("localClockModel");
        this.treeModel = treeModel;
        addModel(treeModel);
        this.globalRateParameter = null;
        this.globalBranchRates = branchRateModel;
        addModel(branchRateModel);
        helper.addTrait(this);
    }

    /**
     * Adds a clock applied to the external (tip) branches of the given taxa.
     *
     * @param taxonList     taxa whose tip branches use this clock
     * @param rateParameter the clock rate; element 0 is read
     * @param relativeRate  if true the rate multiplies the global rate, otherwise it replaces it
     * @throws TreeUtils.MissingTaxonException if a taxon is not in the tree
     */
    public void addExternalBranchClock(TaxonList taxonList, Parameter rateParameter, boolean relativeRate) throws TreeUtils.MissingTaxonException {
        Set<Integer> tips = TreeUtils.getTipsForTaxa(treeModel, taxonList);
        registerTipClock(tips, LocalClock.forTips(rateParameter, null, relativeRate, tips));
        addVariable(rateParameter);
    }

    /**
     * Adds a clock, driven by a branch-rate model, applied to the tip branches of the given taxa.
     *
     * @param taxonList       taxa whose tip branches use this clock
     * @param branchRateModel model providing the clock rate per branch
     * @param relativeRate    if true the rate multiplies the global rate, otherwise it replaces it
     * @throws TreeUtils.MissingTaxonException if a taxon is not in the tree
     */
    public void addExternalBranchClock(TaxonList taxonList, BranchRateModel branchRateModel, boolean relativeRate) throws TreeUtils.MissingTaxonException {
        Set<Integer> tips = TreeUtils.getTipsForTaxa(treeModel, taxonList);
        registerTipClock(tips, LocalClock.forTips(null, branchRateModel, relativeRate, tips));
        addModel(branchRateModel);
    }

    /** Maps every tip number in {@code tips} to {@code clock}. */
    private void registerTipClock(Set<Integer> tips, LocalClock clock) {
        for (int tip : tips) {
            localTipClocks.put(tip, clock);
        }
    }

    /**
     * Adds a clock for the clade whose tip set is exactly {@code taxonList}.
     *
     * @param taxonList     taxa defining the clade
     * @param rateParameter the clock rate; element 0 is read
     * @param relativeRate  if true the rate multiplies the global rate, otherwise it replaces it
     * @param stemParameter optional (may be null) stem proportion, or stem time if {@code stemAsTime}
     * @param stemAsTime    interpret {@code stemParameter} as a time rather than a proportion
     * @param excludeClade  if true only the clade's top node gets the clock, not its descendants
     * @throws TreeUtils.MissingTaxonException if a taxon is not in the tree
     */
    public void addCladeClock(TaxonList taxonList, Parameter rateParameter, boolean relativeRate, Parameter stemParameter, boolean stemAsTime, boolean excludeClade) throws TreeUtils.MissingTaxonException {
        Set<Integer> tips = TreeUtils.getTipsForTaxa(treeModel, taxonList);
        BitSet tipBits = TreeUtils.getTipsBitSetForTaxa(treeModel, taxonList);
        localCladeClocks.put(tipBits, LocalClock.forClade(rateParameter, null, relativeRate, tips, stemParameter, stemAsTime, excludeClade));
        addVariable(rateParameter);
        if (stemParameter != null) {
            addVariable(stemParameter);
        }
    }

    /**
     * Adds a clock, driven by a branch-rate model, for the clade whose tip set is exactly
     * {@code taxonList}. Parameters are as for the {@link Parameter}-based overload.
     *
     * @throws TreeUtils.MissingTaxonException if a taxon is not in the tree
     */
    public void addCladeClock(TaxonList taxonList, BranchRateModel branchRateModel, boolean relativeRate, Parameter stemParameter, boolean stemAsTime, boolean excludeClade) throws TreeUtils.MissingTaxonException {
        Set<Integer> tips = TreeUtils.getTipsForTaxa(treeModel, taxonList);
        BitSet tipBits = TreeUtils.getTipsBitSetForTaxa(treeModel, taxonList);
        localCladeClocks.put(tipBits, LocalClock.forClade(null, branchRateModel, relativeRate, tips, stemParameter, stemAsTime, excludeClade));
        addModel(branchRateModel);
        if (stemParameter != null) {
            addVariable(stemParameter);
        }
    }

    /**
     * Adds the trunk clock: the listed tips and every ancestor of them use this clock.
     *
     * @param taxonList      candidate trunk tips
     * @param rateParameter  the clock rate; element 0 is read
     * @param indexParameter optional (may be null); if given, element 0 selects the single tip (by
     *                       position in {@code taxonList}'s tip list) that defines the trunk
     * @param relativeRate   if true the rate multiplies the global rate, otherwise it replaces it
     * @throws IllegalStateException           if a trunk clock was already added
     * @throws TreeUtils.MissingTaxonException if a taxon is not in the tree
     */
    public void addTrunkClock(TaxonList taxonList, Parameter rateParameter, Parameter indexParameter, boolean relativeRate) throws TreeUtils.MissingTaxonException {
        ensureNoTrunkClock();
        List<Integer> tipList = new ArrayList<Integer>(TreeUtils.getTipsForTaxa(treeModel, taxonList));
        trunkClock = LocalClock.forTrunk(rateParameter, null, indexParameter, relativeRate, tipList);
        addVariable(rateParameter);
        if (indexParameter != null) {
            addVariable(indexParameter);
        }
        addTrunkTrait();
    }

    /**
     * Adds the trunk clock driven by a branch-rate model. Parameters are as for the
     * {@link Parameter}-based overload.
     *
     * @throws IllegalStateException           if a trunk clock was already added
     * @throws TreeUtils.MissingTaxonException if a taxon is not in the tree
     */
    public void addTrunkClock(TaxonList taxonList, BranchRateModel branchRateModel, Parameter indexParameter, boolean relativeRate) throws TreeUtils.MissingTaxonException {
        ensureNoTrunkClock();
        List<Integer> tipList = new ArrayList<Integer>(TreeUtils.getTipsForTaxa(treeModel, taxonList));
        trunkClock = LocalClock.forTrunk(null, branchRateModel, indexParameter, relativeRate, tipList);
        addModel(branchRateModel);
        if (indexParameter != null) {
            addVariable(indexParameter);
        }
        addTrunkTrait();
    }

    /** Rejects a second trunk clock; only one is supported per model. */
    private void ensureNoTrunkClock() {
        if (trunkClock != null) {
            throw new IllegalStateException("Trunk already defined for this LocalClockModel");
        }
    }

    /** Registers a branch trait "trunk" reporting "T" for trunk branches and "B" otherwise. */
    private void addTrunkTrait() {
        helper.addTrait("trunk", new TreeTrait.S() {

            @Override
            public String getTraitName() {
                return "trunk";
            }

            @Override
            public TreeTrait.Intent getIntent() {
                return TreeTrait.Intent.BRANCH;
            }

            @Override
            public String getTrait(Tree tree, NodeRef nodeRef) {
                LocalClockModel.this.setupNodeClocks(tree);
                return LocalClockModel.this.nodeClockMap.get(nodeRef) == LocalClockModel.this.trunkClock ? "T" : "B";
            }
        });
    }

    @Override
    public void handleModelChangedEvent(Model model, Object object, int index) {
        // Any sub-model change (tree topology included) may invalidate the node-clock assignment.
        updateNodeClocks = true;
        fireModelChanged();
    }

    @Override
    protected final void handleVariableChangedEvent(Variable variable, int index, Variable.ChangeType changeType) {
        // Rate values do not affect which node uses which clock; only the trunk index does.
        if (trunkClock != null && variable == trunkClock.indexParameter) {
            updateNodeClocks = true;
        }
        fireModelChanged();
    }

    @Override
    protected void storeState() {
        // Nothing to store: the node-clock map is rebuilt on demand.
    }

    @Override
    protected void restoreState() {
        // The restored tree / trunk index may differ from what the cached map was built for.
        updateNodeClocks = true;
    }

    @Override
    protected void acceptState() {
    }

    @Override
    public TreeTrait[] getTreeTraits() {
        return helper.getTreeTraits();
    }

    @Override
    public TreeTrait getTreeTrait(String key) {
        return helper.getTreeTrait(key);
    }

    /**
     * Returns the rate for the branch connecting {@code nodeRef} to its parent.
     *
     * <p>If no local clock covers the node, the global rate is returned. Otherwise the local clock's
     * rate is used (multiplied by the global rate when the clock is relative). On the stem branch of
     * a clade clock (the clade's top node, whose parent is under a different clock) the rate is a
     * mixture: the stem proportion of the branch uses the local rate and the remainder uses the
     * parent's rate. The stem proportion is the clade's stem value, or the stem time divided by the
     * branch length when the stem is given as a time.
     *
     * @throws IllegalArgumentException if {@code nodeRef} is the root, or if a stem time yields a
     *                                  proportion greater than 1
     */
    @Override
    public double getBranchRate(Tree tree, NodeRef nodeRef) {
        if (tree.isRoot(nodeRef)) {
            throw new IllegalArgumentException("root node doesn't have a rate!");
        }
        setupNodeClocks(tree);

        double globalRate = globalRateParameter != null
                ? globalRateParameter.getParameterValue(0)
                : globalBranchRates.getBranchRate(tree, nodeRef);

        LocalClock clock = nodeClockMap.get(nodeRef);
        if (clock == null) {
            return globalRate;
        }

        double localRate = clock.isRelativeRate()
                ? globalRate * clock.getBranchRate(tree, nodeRef)
                : clock.getBranchRate(tree, nodeRef);

        NodeRef parent = tree.getParent(nodeRef);
        LocalClock parentClock = nodeClockMap.get(parent);
        // Only clade clocks carry a stem parameter. External and trunk clocks have none, so treating
        // their zero stem value as a proportion would hand the whole branch back to the parent rate
        // and the clock would have no effect.
        if (clock == parentClock || clock.getType() != ClockType.CLADE) {
            return localRate;
        }

        // The parent-side rate must come from the parent's own clock, consistent with that clock's
        // relative/absolute flag.
        // NOTE(review): if parent is the root and parentClock is backed by a BranchRateModel, that
        // model is asked for a rate at the root; this class rejects such calls for itself -- confirm
        // the delegated models accept them.
        double parentRate = globalRate;
        if (parentClock != null) {
            parentRate = parentClock.isRelativeRate()
                    ? globalRate * parentClock.getBranchRate(tree, parent)
                    : parentClock.getBranchRate(tree, parent);
        }

        double stemProportion;
        if (clock.isStemAsTime()) {
            stemProportion = clock.getStemValue() / tree.getBranchLength(nodeRef);
            if (stemProportion > 1.0) {
                throw new IllegalArgumentException("A stem proportion for a local clock is > 1.0");
            }
        } else {
            stemProportion = clock.getStemValue();
        }
        return localRate * stemProportion + parentRate * (1.0 - stemProportion);
    }

    /**
     * Rebuilds {@link #nodeClockMap} if it has been invalidated. Tip and clade clocks are assigned
     * first; the trunk clock only fills nodes that are still unassigned, because
     * {@link #setNodeClock} never overwrites an existing entry.
     */
    private void setupNodeClocks(Tree tree) {
        if (!updateNodeClocks) {
            return;
        }
        nodeClockMap.clear();
        setupRateParameters(tree, tree.getRoot(), new BitSet());
        if (trunkClock != null) {
            setupTrunkRates(tree, tree.getRoot());
        }
        updateNodeClocks = false;
    }

    /**
     * Post-order walk that assigns tip and clade clocks. On return, {@code tipsBelow} holds the
     * node numbers of all tips under {@code node}. Children are handled first, so an inner clade
     * claims its nodes before an enclosing clade can.
     */
    private void setupRateParameters(Tree tree, NodeRef node, BitSet tipsBelow) {
        LocalClock clock;
        if (tree.isExternal(node)) {
            tipsBelow.set(node.getNumber());
            clock = localTipClocks.get(node.getNumber());
        } else {
            for (int i = 0; i < tree.getChildCount(node); ++i) {
                // Each child gets its own set so its clade lookup sees only its own tips.
                BitSet childTips = new BitSet();
                setupRateParameters(tree, tree.getChild(node, i), childTips);
                tipsBelow.or(childTips);
            }
            clock = localCladeClocks.get(tipsBelow);
        }
        if (clock != null) {
            setNodeClock(tree, node, clock, clock.excludeClade());
        }
    }

    /**
     * Post-order walk that assigns the trunk clock to trunk tips and all of their ancestors.
     *
     * @return true if {@code node} lies on the trunk
     */
    private boolean setupTrunkRates(Tree tree, NodeRef node) {
        boolean onTrunk = false;
        if (tree.isExternal(node)) {
            onTrunk = isTrunkTip(node.getNumber());
        } else {
            for (int i = 0; i < tree.getChildCount(node); ++i) {
                // Every child must be visited (no short-circuit) so all trunk branches are assigned.
                if (setupTrunkRates(tree, tree.getChild(node, i))) {
                    onTrunk = true;
                }
            }
        }
        if (onTrunk) {
            setNodeClock(tree, node, trunkClock, trunkClock.excludeClade());
        }
        return onTrunk;
    }

    /**
     * Returns whether the tip with the given node number defines the trunk: the single indexed tip
     * when an index parameter is present, otherwise any tip in the trunk's list.
     */
    private boolean isTrunkTip(int tipNumber) {
        if (trunkClock.indexParameter != null) {
            int index = (int) trunkClock.indexParameter.getParameterValue(0);
            return trunkClock.tipList.get(index).intValue() == tipNumber;
        }
        return trunkClock.tipList.contains(tipNumber);
    }

    /**
     * Assigns {@code clock} to {@code node} and, unless {@code excludeClade} is set, to all of its
     * descendants. Nodes that already have a clock keep it.
     */
    private void setNodeClock(Tree tree, NodeRef node, LocalClock clock, boolean excludeClade) {
        if (!tree.isExternal(node) && !excludeClade) {
            for (int i = 0; i < tree.getChildCount(node); ++i) {
                setNodeClock(tree, tree.getChild(node, i), clock, false);
            }
        }
        if (!nodeClockMap.containsKey(node)) {
            nodeClockMap.put(node, clock);
        }
    }

    @Override
    public double getBranchRateDifferential(Tree tree, NodeRef nodeRef) {
        throw new UnsupportedOperationException("Not yet implemented!");
    }

    @Override
    public double getBranchRateSecondDifferential(Tree tree, NodeRef nodeRef) {
        throw new UnsupportedOperationException("Not yet implemented!");
    }

    @Override
    public Parameter getRateParameter() {
        throw new UnsupportedOperationException("Not yet implemented!");
    }

    @Override
    public int getParameterIndexFromNode(NodeRef nodeRef) {
        throw new UnsupportedOperationException("Not yet implemented!");
    }

    @Override
    public ArbitraryBranchRates.BranchRateTransform getTransform() {
        throw new UnsupportedOperationException("Not yet implemented!");
    }

    @Override
    public double[] updateGradientLogDensity(double[] gradient, double[] value, int from, int to) {
        throw new UnsupportedOperationException("Not yet implemented!");
    }

    @Override
    public double[] updateDiagonalHessianLogDensity(double[] diagonalHessian, double[] gradient, double[] value, int from, int to) {
        throw new UnsupportedOperationException("Not yet implemented!");
    }

    @Override
    public Citation.Category getCategory() {
        return Citation.Category.MOLECULAR_CLOCK;
    }

    @Override
    public String getDescription() {
        return "Local clock model";
    }

    @Override
    public List<Citation> getCitations() {
        return Collections.singletonList(CITATION);
    }

    /**
     * One local clock. The rate comes from either {@code rateParameter} (element 0) or
     * {@code branchRates}; exactly one of the two is non-null.
     */
    private static final class LocalClock {
        private final Parameter rateParameter;
        private final BranchRateModel branchRates;
        /** Trunk clocks only: optional selector into {@link #tipList}. */
        private final Parameter indexParameter;
        private final boolean isRelativeRate;
        /** Tip/clade clocks: the member tips (not currently read by this class). */
        private final Set<Integer> tips;
        /** Trunk clocks only: candidate trunk tips. */
        private final List<Integer> tipList;
        private final ClockType type;
        /** Clade clocks only: optional stem proportion or stem time. */
        private final Parameter stemParameter;
        private final boolean stemAsTime;
        private final boolean excludeClade;

        private LocalClock(Parameter rateParameter, BranchRateModel branchRates, Parameter indexParameter,
                           boolean isRelativeRate, Set<Integer> tips, List<Integer> tipList, ClockType type,
                           Parameter stemParameter, boolean stemAsTime, boolean excludeClade) {
            this.rateParameter = rateParameter;
            this.branchRates = branchRates;
            this.indexParameter = indexParameter;
            this.isRelativeRate = isRelativeRate;
            this.tips = tips;
            this.tipList = tipList;
            this.type = type;
            this.stemParameter = stemParameter;
            this.stemAsTime = stemAsTime;
            this.excludeClade = excludeClade;
        }

        /** External-branch clock: applies to the listed tips' own branches only. */
        static LocalClock forTips(Parameter rateParameter, BranchRateModel branchRates, boolean isRelativeRate, Set<Integer> tips) {
            return new LocalClock(rateParameter, branchRates, null, isRelativeRate, tips, null,
                    ClockType.EXTERNAL, null, false, true);
        }

        /** Clade clock with optional stem parameter. */
        static LocalClock forClade(Parameter rateParameter, BranchRateModel branchRates, boolean isRelativeRate, Set<Integer> tips,
                                   Parameter stemParameter, boolean stemAsTime, boolean excludeClade) {
            return new LocalClock(rateParameter, branchRates, null, isRelativeRate, tips, null,
                    ClockType.CLADE, stemParameter, stemAsTime, excludeClade);
        }

        /** Trunk clock; each trunk node is assigned individually (excludeClade = true). */
        static LocalClock forTrunk(Parameter rateParameter, BranchRateModel branchRates, Parameter indexParameter,
                                   boolean isRelativeRate, List<Integer> tipList) {
            return new LocalClock(rateParameter, branchRates, indexParameter, isRelativeRate, null, tipList,
                    ClockType.TRUNK, null, false, true);
        }

        /** Returns element 0 of the stem parameter, or 0.0 when there is none. */
        double getStemValue() {
            return stemParameter != null ? stemParameter.getParameterValue(0) : 0.0;
        }

        boolean isStemAsTime() {
            return stemAsTime;
        }

        boolean excludeClade() {
            return excludeClade;
        }

        ClockType getType() {
            return type;
        }

        boolean isRelativeRate() {
            return isRelativeRate;
        }

        /** Returns this clock's rate for the branch above {@code nodeRef}. */
        double getBranchRate(Tree tree, NodeRef nodeRef) {
            if (rateParameter != null) {
                return rateParameter.getParameterValue(0);
            }
            return branchRates.getBranchRate(tree, nodeRef);
        }
    }

    /** Kind of local clock; determines how it is assigned to nodes and whether a stem applies. */
    enum ClockType {
        CLADE,
        TRUNK,
        EXTERNAL
    }
}

