/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.tree;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeUtils;
import dr.evolution.util.Taxa;
import dr.evomodel.branchratemodel.DifferentiableBranchRates;
import dr.evomodel.tree.TreeModel;
import dr.evomodel.tree.TreeStatistic;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;

/**
 * Tree statistic reporting, for each user-specified taxon set, an average branch rate over the
 * branches of the "paraphyly" defined by that set.
 *
 * <p>For dimension {@code n}, the statistic finds the MRCA of taxon set {@code n} and walks the
 * subtree below it. It does not descend below the MRCA of any other taxon set. Every visited node
 * that is neither the MRCA of some taxon set nor the root contributes its branch. Contributions
 * are combined according to the configured {@link BranchWeighting}.
 */
public class ParaphylyRateStatistic
extends TreeStatistic {
    public static final String PARAPHYLY_RATE_STATISTIC = "paraphylyRateStatistic";
    public static final String PARAPHYLY_LIST = "paraphylyList";
    public static final String WEIGHTING = "weighting";

    /** Taxon sets; the first {@link #dim} entries are reported, one per dimension. */
    private final List<Taxa> paraphylySet;
    private final DifferentiableBranchRates branchRateModel;
    private Tree tree;
    /** Denominator accumulated by {@link #recurseToAccumulateRate}; reset on each statistic call. */
    private double totalTime;
    /** MRCA node of each taxon set, recomputed on every {@link #getStatisticValue} call. */
    private final List<NodeRef> MRCANodeList;
    private final int dim;
    private final BranchWeighting branchWeighting;

    public static XMLObjectParser PARSER = new AbstractXMLObjectParser(){
        private final XMLSyntaxRule[] rules = new XMLSyntaxRule[]{
                new ElementRule(DifferentiableBranchRates.class),
                new ElementRule(TreeModel.class),
                new ElementRule(ParaphylyRateStatistic.PARAPHYLY_LIST,
                        new XMLSyntaxRule[]{new ElementRule(Taxa.class, 1, Integer.MAX_VALUE)}),
                AttributeRule.newStringRule(ParaphylyRateStatistic.WEIGHTING, true)
        };

        @Override
        public String getParserName() {
            return ParaphylyRateStatistic.PARAPHYLY_RATE_STATISTIC;
        }

        /**
         * Builds the statistic from its XML element: one dimension per {@code Taxa} child of the
         * {@code paraphylyList} element, with the optional {@code weighting} attribute
         * (default {@code "none"}).
         */
        @Override
        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            String name = xo.getAttribute("name", xo.getId());
            DifferentiableBranchRates branchRates = (DifferentiableBranchRates)xo.getChild(DifferentiableBranchRates.class);
            TreeModel treeModel = (TreeModel)xo.getChild(TreeModel.class);

            List<Taxa> taxaSets = new ArrayList<Taxa>();
            if (xo.hasChildNamed(ParaphylyRateStatistic.PARAPHYLY_LIST)) {
                XMLObject listXo = xo.getChild(ParaphylyRateStatistic.PARAPHYLY_LIST);
                for (int i = 0; i < listXo.getChildCount(); ++i) {
                    // The syntax rules only admit Taxa children inside paraphylyList.
                    taxaSets.add((Taxa)listXo.getChild(i));
                }
            }

            BranchWeighting weighting = this.parseWeighting(xo);
            return new ParaphylyRateStatistic(name, branchRates, treeModel, taxaSets, weighting, taxaSets.size());
        }

        /**
         * Reads the {@code weighting} attribute (case-insensitive), defaulting to {@code "none"}.
         *
         * @throws XMLParseException if the attribute names no known {@link BranchWeighting}
         */
        private BranchWeighting parseWeighting(XMLObject xo) throws XMLParseException {
            String weightingName = xo.getAttribute(ParaphylyRateStatistic.WEIGHTING, BranchWeighting.NONE.getName());
            BranchWeighting weighting = BranchWeighting.parse(weightingName);
            if (weighting == null) {
                throw new XMLParseException("Unknown weighting type: " + weightingName);
            }
            return weighting;
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }

        @Override
        public String getParserDescription() {
            return "Reports weighted or unweighted average branch-rate parameter across complete disjoint set of user-specified paraphylies";
        }

        @Override
        public Class getReturnType() {
            return ParaphylyRateStatistic.class;
        }
    };

    /**
     * @param name            statistic name
     * @param branchRateModel source of per-branch rates
     * @param treeModel       tree whose branches are summarised
     * @param paraphylySet    taxon sets; the first {@code dim} are used
     * @param branchWeighting how branch rates are combined into an average
     * @param dim             number of reported dimensions
     * @throws IllegalArgumentException if {@code dim} is negative or exceeds {@code paraphylySet.size()}
     */
    public ParaphylyRateStatistic(String name, DifferentiableBranchRates branchRateModel, TreeModel treeModel, List<Taxa> paraphylySet, BranchWeighting branchWeighting, int dim) {
        super(name);
        if (dim < 0 || dim > paraphylySet.size()) {
            throw new IllegalArgumentException("Dimension " + dim
                    + " must be between 0 and the number of taxon sets (" + paraphylySet.size() + ")");
        }
        this.branchRateModel = branchRateModel;
        this.paraphylySet = paraphylySet;
        this.tree = treeModel;
        this.branchWeighting = branchWeighting;
        this.dim = dim;
        this.MRCANodeList = new ArrayList<NodeRef>(dim);
        for (int i = 0; i < dim; ++i) {
            // Placeholders; filled by updateMRCAList() before each use.
            this.MRCANodeList.add(null);
        }
    }

    @Override
    public void setTree(Tree tree) {
        this.tree = tree;
    }

    @Override
    public Tree getTree() {
        return this.tree;
    }

    @Override
    public int getDimension() {
        return this.dim;
    }

    /**
     * Returns the averaged branch rate for taxon set {@code n}.
     *
     * <p>The result is the sum of {@code branchWeighting.getBranchRate} over the counted branches,
     * divided by the sum of {@code branchWeighting.getDenominator} over the same branches. When no
     * branch is counted (e.g. the set's MRCA is a tip), this is {@code 0.0 / 0.0}, i.e. NaN.
     */
    @Override
    public double getStatisticValue(int n) {
        this.updateMRCAList();
        List<NodeRef> otherMRCAs = new ArrayList<NodeRef>(Math.max(this.dim - 1, 0));
        for (int i = 0; i < this.dim; ++i) {
            if (i != n) {
                otherMRCAs.add(this.MRCANodeList.get(i));
            }
        }
        this.totalTime = 0.0;
        double weightedRateSum = this.recurseToAccumulateRate(this.MRCANodeList.get(n), otherMRCAs);
        return weightedRateSum / this.totalTime;
    }

    /**
     * Recomputes the MRCA of every taxon set on the current tree.
     *
     * @throws RuntimeException if a taxon is absent from the tree (cause preserved), or if no
     *                          common ancestor is found for a set
     */
    private void updateMRCAList() {
        for (int i = 0; i < this.dim; ++i) {
            try {
                Set<String> leaves = TreeUtils.getLeavesForTaxa(this.tree, this.paraphylySet.get(i));
                NodeRef mrca = TreeUtils.getCommonAncestorNode(this.tree, leaves);
                if (mrca == null) {
                    throw new RuntimeException("No clade found that contains " + leaves);
                }
                this.MRCANodeList.set(i, mrca);
            }
            catch (TreeUtils.MissingTaxonException missingTaxonException) {
                // Keep the cause so the offending taxon is visible in the stack trace.
                throw new RuntimeException("Missing taxon!", missingTaxonException);
            }
        }
    }

    /**
     * Sums weighted branch rates over the subtree rooted at {@code node}, adding the matching
     * denominators to {@link #totalTime}.
     *
     * <p>Recursion stops at tips and at nodes in {@code stopNodes}, so those nodes' descendants are
     * skipped. A node's own branch is counted only if it is not the MRCA of any taxon set and not
     * the root. This excludes the stem branch of the starting MRCA and of every nested MRCA. All
     * children are visited, so multifurcating nodes are handled too.
     */
    private double recurseToAccumulateRate(NodeRef node, List<NodeRef> stopNodes) {
        double sum = 0.0;
        if (!this.tree.isExternal(node) && !stopNodes.contains(node)) {
            int childCount = this.tree.getChildCount(node);
            for (int c = 0; c < childCount; ++c) {
                sum += this.recurseToAccumulateRate(this.tree.getChild(node, c), stopNodes);
            }
        }
        if (!this.MRCANodeList.contains(node) && !this.tree.isRoot(node)) {
            sum += this.branchWeighting.getBranchRate(this.branchRateModel, this.tree, node);
            this.totalTime += this.branchWeighting.getDenominator(this.tree, node);
        }
        return sum;
    }

    /**
     * How branch rates are averaged. Each constant supplies a per-branch numerator
     * ({@link #getBranchRate}) and denominator ({@link #getDenominator}).
     */
    public static enum BranchWeighting {
        /** Unweighted mean: numerator is the branch rate, denominator is 1 per branch. */
        NONE("none"){

            @Override
            double getBranchRate(DifferentiableBranchRates differentiableBranchRates, Tree tree, NodeRef nodeRef) {
                return differentiableBranchRates.getUntransformedBranchRate(tree, nodeRef);
            }

            @Override
            double getDenominator(Tree tree, NodeRef nodeRef) {
                return 1.0;
            }
        }
        ,
        /** Length-weighted mean: numerator is rate times branch length, denominator is branch length. */
        BY_TIME("byTime"){

            @Override
            double getBranchRate(DifferentiableBranchRates differentiableBranchRates, Tree tree, NodeRef nodeRef) {
                return differentiableBranchRates.getUntransformedBranchRate(tree, nodeRef) * tree.getBranchLength(nodeRef);
            }

            @Override
            double getDenominator(Tree tree, NodeRef nodeRef) {
                return tree.getBranchLength(nodeRef);
            }
        };

        private final String name;

        private BranchWeighting(String name) {
            this.name = name;
        }

        /** Returns the XML name of this weighting. */
        public String getName() {
            return this.name;
        }

        /** Per-branch contribution to the numerator of the average. */
        abstract double getBranchRate(DifferentiableBranchRates branchRates, Tree tree, NodeRef node);

        /** Per-branch contribution to the denominator of the average. */
        abstract double getDenominator(Tree tree, NodeRef node);

        /**
         * Looks up a weighting by its XML name, ignoring case.
         *
         * @return the matching constant, or {@code null} if none matches
         */
        public static BranchWeighting parse(String name) {
            for (BranchWeighting weighting : BranchWeighting.values()) {
                if (weighting.getName().equalsIgnoreCase(name)) {
                    return weighting;
                }
            }
            return null;
        }
    }
}

