/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treedatalikelihood.discrete;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evolution.tree.TreeTraitProvider;
import dr.evomodel.branchmodel.BranchModel;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.treedatalikelihood.BeagleDataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.preorder.AbstractBeagleGradientDelegate;
import java.util.Arrays;

/**
 * Gradient delegate that asks BEAGLE for cross-product differentials over the tree's branches
 * and exposes them as a whole-tree trait.
 *
 * <p>The result holds one {@code stateCount x stateCount} block per substitution model of the
 * branch model. The block for model {@code m} occupies indices
 * {@code [m * stateCount * stateCount, (m + 1) * stateCount * stateCount)}.
 */
public class SubstitutionModelCrossProductDelegate
extends AbstractBeagleGradientDelegate {
    private static final String GRADIENT_TRAIT_NAME = "substitutionModelCrossProductGradient";

    private final String name;
    private final Tree tree;
    private final BranchRateModel branchRateModel;
    private final BranchModel branchModel;
    private final int stateCount;
    private final int substitutionModelCount;

    /**
     * @param name               suffix appended to the trait name (see {@link #getName(String)})
     * @param tree               tree whose non-root branches are traversed
     * @param likelihoodDelegate BEAGLE likelihood delegate; its branch model determines how many
     *                           substitution models the gradient covers
     * @param branchRateModel    rates that turn branch durations into branch lengths
     * @param stateCount         number of states of each substitution model
     */
    public SubstitutionModelCrossProductDelegate(String name, Tree tree, BeagleDataLikelihoodDelegate likelihoodDelegate, BranchRateModel branchRateModel, int stateCount) {
        super(name, tree, likelihoodDelegate);
        this.name = name;
        this.tree = tree;
        this.stateCount = stateCount;
        this.branchRateModel = branchRateModel;
        this.branchModel = likelihoodDelegate.getBranchModel();
        this.substitutionModelCount = this.branchModel.getSubstitutionModels().size();
    }

    /**
     * Returns the length of the branch above {@code node}: its rate multiplied by its duration
     * (parent height minus node height). Must not be called on the root, which has no parent.
     */
    private double getBranchLength(NodeRef node) {
        final double rate;
        // Rate lookups are serialized on the shared rate-model instance.
        synchronized (branchRateModel) {
            rate = branchRateModel.getBranchRate(tree, node);
        }
        double parentHeight = tree.getNodeHeight(tree.getParent(node));
        double nodeHeight = tree.getNodeHeight(node);
        return rate * (parentHeight - nodeHeight);
    }

    @Override
    protected int getGradientLength() {
        return stateCount * stateCount * substitutionModelCount;
    }

    /**
     * Fills the buffers with one entry per non-root branch, using the full branch length.
     *
     * @return the number of entries written
     */
    private int coverWholeTree(int[] postOrderIndices, int[] preOrderIndices, double[] branchLengths) {
        int count = 0;
        for (int i = 0; i < tree.getNodeCount(); ++i) {
            NodeRef node = tree.getNode(i);
            if (tree.isRoot(node)) {
                continue;
            }
            postOrderIndices[count] = getPostOrderPartialIndex(i);
            preOrderIndices[count] = getPreOrderPartialIndex(i);
            branchLengths[count] = getBranchLength(node);
            ++count;
        }
        // Report what was actually written rather than assuming exactly nodeCount - 1 branches.
        return count;
    }

    /**
     * Fills the buffers with one entry per (branch, mapping slot) pair whose branch-model mapping
     * assigns substitution model {@code modelIndex}. The entry's length is the branch length
     * scaled by that slot's share of the branch's total mapping weight.
     *
     * @return the number of entries written
     */
    private int coverPartialTree(int modelIndex, int[] postOrderIndices, int[] preOrderIndices, double[] branchLengths) {
        int count = 0;
        for (int i = 0; i < tree.getNodeCount(); ++i) {
            NodeRef node = tree.getNode(i);
            if (tree.isRoot(node)) {
                continue;
            }
            BranchModel.Mapping mapping = branchModel.getBranchModelMapping(node);
            int[] order = mapping.getOrder();
            double[] weights = mapping.getWeights();
            // The total is the same for every slot of this branch, so compute it once.
            double totalWeight = sum(weights);
            for (int k = 0; k < order.length; ++k) {
                if (order[k] != modelIndex) {
                    continue;
                }
                postOrderIndices[count] = getPostOrderPartialIndex(i);
                preOrderIndices[count] = getPreOrderPartialIndex(i);
                branchLengths[count] = getBranchLength(node) * relativeWeight(weights[k], totalWeight);
                ++count;
            }
        }
        return count;
    }

    /** Returns the sum of {@code values}. */
    private static double sum(double[] values) {
        double total = 0.0;
        for (double value : values) {
            total += value;
        }
        return total;
    }

    /**
     * Returns {@code weight / totalWeight}. Returns 0 when the total is zero, so that a
     * degenerate mapping yields a zero length instead of sending NaN to BEAGLE.
     */
    private static double relativeWeight(double weight, double totalWeight) {
        return totalWeight == 0.0 ? 0.0 : weight / totalWeight;
    }

    /**
     * Computes the cross-product differentials into {@code gradient}.
     *
     * <p>With a single substitution model, all branches are passed to BEAGLE in one call. With
     * several models, BEAGLE is called once per model, using only the branch segments mapped to
     * that model. Each result is copied into that model's block of {@code gradient}.
     *
     * @param tree             tree supplied by the base class; buffers are sized from the
     *                         delegate's own tree, which is the tree the cover methods iterate
     * @param gradient         output, at least {@code stateCount^2 * substitutionModelCount} long
     * @param secondDerivatives second-order output buffer; must be {@code null} (not supported)
     * @throws UnsupportedOperationException if {@code secondDerivatives} is non-null
     */
    @Override
    protected void getNodeDerivatives(Tree tree, double[] gradient, double[] secondDerivatives) {
        final int matrixSize = stateCount * stateCount;
        assert (gradient.length >= matrixSize * substitutionModelCount);
        assert (secondDerivatives == null || secondDerivatives.length >= matrixSize * substitutionModelCount);
        if (secondDerivatives != null) {
            throw new UnsupportedOperationException("Not yet implemented");
        }

        final int branchCount = this.tree.getNodeCount() - 1;
        int[] postOrderIndices = new int[branchCount];
        int[] preOrderIndices = new int[branchCount];
        double[] branchLengths = new double[branchCount];

        if (substitutionModelCount == 1) {
            Arrays.fill(gradient, 0.0);
            int count = coverWholeTree(postOrderIndices, preOrderIndices, branchLengths);
            beagle.calculateCrossProductDifferentials(postOrderIndices, preOrderIndices, new int[]{0}, new int[]{0}, branchLengths, count, gradient, null);
        } else {
            double[] modelBuffer = new double[matrixSize];
            for (int model = 0; model < substitutionModelCount; ++model) {
                Arrays.fill(modelBuffer, 0.0);
                int count = coverPartialTree(model, postOrderIndices, preOrderIndices, branchLengths);
                beagle.calculateCrossProductDifferentials(postOrderIndices, preOrderIndices, new int[]{0}, new int[]{0}, branchLengths, count, modelBuffer, null);
                System.arraycopy(modelBuffer, 0, gradient, model * matrixSize, matrixSize);
            }
        }
    }

    @Override
    protected String getGradientTraitName() {
        return getName(name);
    }

    /** Returns the trait name used by a delegate constructed with {@code name}. */
    public static String getName(String name) {
        return GRADIENT_TRAIT_NAME + "." + name;
    }

    @Override
    protected void constructTraits(TreeTraitProvider.Helper helper) {
        helper.addTrait(new TreeTrait.DA() {

            @Override
            public String getTraitName() {
                return SubstitutionModelCrossProductDelegate.this.getGradientTraitName();
            }

            @Override
            public TreeTrait.Intent getIntent() {
                return TreeTrait.Intent.WHOLE_TREE;
            }

            @Override
            public double[] getTrait(Tree tree, NodeRef nodeRef) {
                return SubstitutionModelCrossProductDelegate.this.getGradient(nodeRef);
            }
        });
    }
}

