/*
 * Decompiled with CFR 0.152.
 */
package dr.evolution.parsimony;

import dr.evolution.alignment.PatternList;
import dr.evolution.alignment.Patterns;
import dr.evolution.datatype.Nucleotides;
import dr.evolution.parsimony.FitchParsimony;
import dr.evolution.parsimony.ParsimonyCriterion;
import dr.evolution.tree.FlexibleNode;
import dr.evolution.tree.FlexibleTree;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.util.Taxon;
import java.util.TreeSet;

/**
 * Sankoff (weighted) parsimony criterion over a {@link PatternList}.
 *
 * <p>Scoring is a post-order pass. For every internal node, pattern and candidate state, it
 * accumulates the cheapest cost of explaining each child subtree given that state.
 * {@code costMatrix[parent][child]} is the cost of a change along a branch. Ancestral states are
 * recovered by a pre-order traceback.
 *
 * <p>For each pattern, only the states allowed by at least one of that pattern's observed codes
 * are used as candidates. NOTE(review): this restriction is exact for the default unit cost
 * matrix. For a general cost matrix, the true optimum may pass through a state that is not
 * observed in the pattern, in which case the score reported here is an upper bound.
 *
 * <p>Results are cached per tree instance, compared by reference. The instance holds mutable
 * caches and is not synchronized.
 */
public class SankoffParsimony
implements ParsimonyCriterion {
    /** Number of states of the pattern data type. */
    private final int stateCount;
    /** Per pattern: ascending candidate states (those allowed by any code observed in it). */
    private int[][] stateSets;
    /** [node number][pattern][state]: minimum cost of the subtree below the node given the state. */
    private double[][][] nodeScores;
    /** [node number][pattern]: reconstructed state of the node. */
    private int[][] nodeStates;
    /** Tree the caches were built for; compared by reference. */
    private Tree tree = null;
    private final PatternList patterns;
    /** costMatrix[i][j] is the cost of a change from parent state i to child state j. */
    private final double[][] costMatrix;
    private boolean hasCalculatedSteps = false;
    private boolean hasReconstructedStates = false;
    /** Per-pattern minimum scores at the root; sized at construction time. */
    private final double[] siteScores;

    /**
     * Creates a Sankoff criterion with unit costs: 0 for no change and 1 for any change.
     *
     * @param patternList the site patterns to score; must not be null
     * @throws IllegalArgumentException if {@code patternList} is null
     */
    public SankoffParsimony(PatternList patternList) {
        this(patternList, unitCostMatrix(patternList));
    }

    /**
     * Creates a Sankoff criterion with an explicit cost matrix.
     *
     * @param patternList the site patterns to score; must not be null
     * @param dArray square matrix of side {@code stateCount}. {@code dArray[i][j]} is the cost of
     *     changing from parent state {@code i} to child state {@code j}. The matrix is copied, so
     *     later changes by the caller have no effect.
     * @throws IllegalArgumentException if either argument is null or the matrix is not
     *     {@code stateCount x stateCount}
     */
    public SankoffParsimony(PatternList patternList, double[][] dArray) {
        if (patternList == null) {
            throw new IllegalArgumentException("The patterns cannot be null");
        }
        this.stateCount = patternList.getDataType().getStateCount();
        if (dArray == null) {
            throw new IllegalArgumentException("The cost matrix cannot be null");
        }
        if (dArray.length != this.stateCount) {
            throw wrongDimension(this.stateCount);
        }
        // Check every row (not just the first) and copy, so the cached scores cannot be
        // invalidated by the caller mutating its array afterwards.
        double[][] copy = new double[this.stateCount][];
        for (int i = 0; i < this.stateCount; ++i) {
            if (dArray[i] == null || dArray[i].length != this.stateCount) {
                throw wrongDimension(this.stateCount);
            }
            copy[i] = dArray[i].clone();
        }
        this.costMatrix = copy;
        this.patterns = patternList;
        this.siteScores = new double[patternList.getPatternCount()];
    }

    /** Builds the unit cost matrix (0 on the diagonal, 1 elsewhere) for the patterns' data type. */
    private static double[][] unitCostMatrix(PatternList patternList) {
        if (patternList == null) {
            throw new IllegalArgumentException("The patterns cannot be null");
        }
        int n = patternList.getDataType().getStateCount();
        double[][] costs = new double[n][n];
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                costs[i][j] = i == j ? 0.0 : 1.0;
            }
        }
        return costs;
    }

    /** Creates the exception reported for a cost matrix of the wrong shape. */
    private static IllegalArgumentException wrongDimension(int stateCount) {
        return new IllegalArgumentException("The cost matrix is of the wrong dimension: expecting " + stateCount + " square");
    }

    /**
     * Returns the minimum parsimony score of each pattern on {@code tree}.
     *
     * <p>The internal array is returned, not a copy. Results are cached until a different tree
     * instance is passed. NOTE(review): a tree mutated in place is not detected and yields stale
     * scores.
     *
     * @throws IllegalArgumentException if {@code tree} is null
     */
    @Override
    public double[] getSiteScores(Tree tree) {
        if (tree == null) {
            throw new IllegalArgumentException("The tree cannot be null");
        }
        if (this.tree != tree) {
            this.tree = tree;
            this.initialize();
        }
        if (!this.hasCalculatedSteps) {
            this.calculateSteps(tree, tree.getRoot());
            double[][] rootScores = this.nodeScores[tree.getRoot().getNumber()];
            for (int i = 0; i < this.siteScores.length; ++i) {
                this.siteScores[i] = this.minScore(rootScores[i], this.stateSets[i]);
            }
            this.hasCalculatedSteps = true;
        }
        return this.siteScores;
    }

    /** Returns the sum of the site scores weighted by the pattern weights. */
    @Override
    public double getScore(Tree tree) {
        this.getSiteScores(tree);
        double total = 0.0;
        for (int i = 0; i < this.patterns.getPatternCount(); ++i) {
            total += this.siteScores[i] * this.patterns.getPatternWeight(i);
        }
        return total;
    }

    /**
     * Returns the reconstructed state of every pattern at {@code nodeRef}.
     *
     * <p>The root takes a minimum-score candidate state. Every other node takes the candidate
     * that minimizes its subtree score plus the cost of the change from its parent's state. Ties
     * go to the first candidate in ascending state order. The internal array is returned.
     */
    @Override
    public int[] getStates(Tree tree, NodeRef nodeRef) {
        this.getSiteScores(tree);
        if (!this.hasReconstructedStates) {
            NodeRef root = tree.getRoot();
            int rootNumber = root.getNumber();
            int[] rootStates = this.nodeStates[rootNumber];
            for (int i = 0; i < this.patterns.getPatternCount(); ++i) {
                rootStates[i] = this.minState(this.nodeScores[rootNumber][i], this.stateSets[i]);
            }
            // The root has no parent branch, so no transition cost applies to it; trace back
            // from its children only.
            for (int c = 0; c < tree.getChildCount(root); ++c) {
                this.reconstructStates(tree, tree.getChild(root, c), rootStates);
            }
            this.hasReconstructedStates = true;
        }
        return this.nodeStates[nodeRef.getNumber()];
    }

    /**
     * (Re)builds all caches for {@link #tree}: candidate state sets, tip scores (0 for allowed
     * states, +infinity otherwise) and zeroed internal-node scores.
     */
    private void initialize() {
        this.hasCalculatedSteps = false;
        this.hasReconstructedStates = false;
        int patternCount = this.patterns.getPatternCount();
        int nodeCount = this.tree.getNodeCount();
        int tipCount = this.tree.getExternalNodeCount();
        this.stateSets = new int[patternCount][];
        // Allocate every node's rows up front and index by getNumber(), which is what the
        // scoring and traceback passes use.
        this.nodeScores = new double[nodeCount][patternCount][this.stateCount];
        this.nodeStates = new int[nodeCount][patternCount];

        // Taxon lookup does not depend on the pattern, so do it once per tip.
        int[] tipNumbers = new int[tipCount];
        int[] tipTaxonIndices = new int[tipCount];
        for (int t = 0; t < tipCount; ++t) {
            NodeRef tip = this.tree.getExternalNode(t);
            String taxonId = this.tree.getNodeTaxon(tip).getId();
            int taxonIndex = this.patterns.getTaxonIndex(taxonId);
            // NOTE(review): assumes getTaxonIndex signals a missing taxon with a negative index.
            if (taxonIndex < 0) {
                throw new IllegalArgumentException("Taxon " + taxonId + " in the tree is not present in the patterns");
            }
            tipNumbers[t] = tip.getNumber();
            tipTaxonIndices[t] = taxonIndex;
        }

        for (int i = 0; i < patternCount; ++i) {
            int[] pattern = this.patterns.getPattern(i);
            this.stateSets[i] = this.collectCandidateStates(pattern);
            for (int t = 0; t < tipCount; ++t) {
                boolean[] allowed = this.patterns.getDataType().getStateSet(pattern[tipTaxonIndices[t]]);
                double[] scores = this.nodeScores[tipNumbers[t]][i];
                for (int s = 0; s < this.stateCount; ++s) {
                    scores[s] = allowed[s] ? 0.0 : Double.POSITIVE_INFINITY;
                }
            }
        }
    }

    /** Returns, in ascending order, every state allowed by at least one code in {@code pattern}. */
    private int[] collectCandidateStates(int[] pattern) {
        TreeSet<Integer> observed = new TreeSet<Integer>();
        for (int code : pattern) {
            boolean[] allowed = this.patterns.getDataType().getStateSet(code);
            for (int s = 0; s < allowed.length; ++s) {
                if (allowed[s]) {
                    observed.add(s);
                }
            }
        }
        int[] states = new int[observed.size()];
        int k = 0;
        for (int s : observed) {
            states[k++] = s;
        }
        return states;
    }

    /**
     * Post-order pass. For each internal node, pattern and candidate state, it stores the sum
     * over children of the cheapest (child score + change cost). Tips keep their initial scores.
     */
    private void calculateSteps(Tree tree, NodeRef node) {
        if (tree.isExternal(node)) {
            return;
        }
        int childCount = tree.getChildCount(node);
        int[] childNumbers = new int[childCount];
        for (int c = 0; c < childCount; ++c) {
            NodeRef child = tree.getChild(node, c);
            this.calculateSteps(tree, child);
            childNumbers[c] = child.getNumber();
        }
        double[][] scoresHere = this.nodeScores[node.getNumber()];
        for (int p = 0; p < this.patterns.getPatternCount(); ++p) {
            int[] candidates = this.stateSets[p];
            double[] target = scoresHere[p];
            for (int k = 0; k < candidates.length; ++k) {
                double total = this.minCost(k, this.nodeScores[childNumbers[0]][p], this.costMatrix, candidates);
                for (int c = 1; c < childCount; ++c) {
                    total += this.minCost(k, this.nodeScores[childNumbers[c]][p], this.costMatrix, candidates);
                }
                target[candidates[k]] = total;
            }
        }
    }

    /** Pre-order traceback: picks each pattern's state at {@code node} given the parent's states. */
    private void reconstructStates(Tree tree, NodeRef node, int[] parentStates) {
        int number = node.getNumber();
        int[] states = this.nodeStates[number];
        for (int p = 0; p < this.patterns.getPatternCount(); ++p) {
            states[p] = this.bestStateGivenParent(this.nodeScores[number][p], this.stateSets[p], parentStates[p]);
        }
        for (int c = 0; c < tree.getChildCount(node); ++c) {
            this.reconstructStates(tree, tree.getChild(node, c), states);
        }
    }

    /** Returns the candidate minimizing {@code scores[s] + costMatrix[parentState][s]}; first wins ties. */
    private int bestStateGivenParent(double[] scores, int[] candidates, int parentState) {
        double[] fromParent = this.costMatrix[parentState];
        int best = candidates[0];
        double bestCost = scores[best] + fromParent[best];
        for (int k = 1; k < candidates.length; ++k) {
            int s = candidates[k];
            double cost = scores[s] + fromParent[s];
            if (cost < bestCost) {
                best = s;
                bestCost = cost;
            }
        }
        return best;
    }

    /** Returns the candidate state with the lowest score; the first one wins ties. */
    private int minState(double[] dArray, int[] nArray) {
        int best = nArray[0];
        for (int k = 1; k < nArray.length; ++k) {
            if (dArray[nArray[k]] < dArray[best]) {
                best = nArray[k];
            }
        }
        return best;
    }

    /** Returns the lowest score among the candidate states. */
    private double minScore(double[] dArray, int[] nArray) {
        double best = dArray[nArray[0]];
        for (int k = 1; k < nArray.length; ++k) {
            if (dArray[nArray[k]] < best) {
                best = dArray[nArray[k]];
            }
        }
        return best;
    }

    /**
     * Returns the cheapest cost of a child subtree given that the parent is in candidate state
     * {@code nArray[n]}. The result is the minimum, over candidate child states s, of
     * {@code dArray2[nArray[n]][s] + dArray[s]}.
     */
    private double minCost(int n, double[] dArray, double[][] dArray2, int[] nArray) {
        double[] fromParent = dArray2[nArray[n]];
        double best = fromParent[nArray[0]] + dArray[nArray[0]];
        for (int k = 1; k < nArray.length; ++k) {
            double cost = fromParent[nArray[k]] + dArray[nArray[k]];
            if (cost < best) {
                best = cost;
            }
        }
        return best;
    }

    /** Small demo comparing Fitch and Sankoff (unit cost) results on a five-taxon tree. */
    public static void main(String[] args) {
        FlexibleNode tip1 = new FlexibleNode(new Taxon("tip1"));
        FlexibleNode tip2 = new FlexibleNode(new Taxon("tip2"));
        FlexibleNode tip3 = new FlexibleNode(new Taxon("tip3"));
        FlexibleNode tip4 = new FlexibleNode(new Taxon("tip4"));
        FlexibleNode tip5 = new FlexibleNode(new Taxon("tip5"));
        FlexibleNode node1 = new FlexibleNode();
        node1.addChild(tip1);
        node1.addChild(tip2);
        FlexibleNode node2 = new FlexibleNode();
        node2.addChild(tip4);
        node2.addChild(tip5);
        FlexibleNode node3 = new FlexibleNode();
        node3.addChild(tip3);
        node3.addChild(node2);
        FlexibleNode root = new FlexibleNode();
        root.addChild(node1);
        root.addChild(node3);
        FlexibleTree tree = new FlexibleTree(root);
        Patterns patterns = new Patterns(Nucleotides.INSTANCE, tree);
        patterns.addPattern(new int[]{2, 3, 1, 3, 3});
        FitchParsimony fitch = new FitchParsimony(patterns, false);
        SankoffParsimony sankoff = new SankoffParsimony(patterns);
        for (int i = 0; i < patterns.getPatternCount(); ++i) {
            double[] scores = fitch.getSiteScores(tree);
            System.out.println("Pattern = " + i);
            System.out.println("Fitch:");
            System.out.println("  No. Steps = " + scores[i]);
            System.out.println("    state(node1) = " + fitch.getStates(tree, node1)[i]);
            System.out.println("    state(node2) = " + fitch.getStates(tree, node2)[i]);
            System.out.println("    state(node3) = " + fitch.getStates(tree, node3)[i]);
            System.out.println("    state(root) = " + fitch.getStates(tree, root)[i]);
            scores = sankoff.getSiteScores(tree);
            System.out.println("Sankoff:");
            System.out.println("  No. Steps = " + scores[i]);
            System.out.println("    state(node1) = " + sankoff.getStates(tree, node1)[i]);
            System.out.println("    state(node2) = " + sankoff.getStates(tree, node2)[i]);
            System.out.println("    state(node3) = " + sankoff.getStates(tree, node3)[i]);
            System.out.println("    state(root) = " + sankoff.getStates(tree, root)[i]);
            System.out.println();
        }
    }
}

