/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.coalescent.smooth;

import dr.evolution.coalescent.IntervalList;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.util.Units;
import dr.evomodel.bigfasttree.BigFastTreeIntervals;
import dr.evomodel.coalescent.AbstractCoalescentLikelihood;
import dr.evomodel.coalescent.smooth.GlobalSigmoidSmoothFunction;
import dr.evomodel.tree.TreeModel;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.util.Author;
import dr.util.Citable;
import dr.util.Citation;
import dr.util.CommonCitations;
import dr.xml.Reportable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Skygrid-style coalescent likelihood in which the grid change points and the lineage-count
 * changes are smoothed with a {@link GlobalSigmoidSmoothFunction}, giving a "Differentiable
 * skygrid coalescent" (see {@link #getDescription()}).
 *
 * <p>The likelihood and gradient code currently support exactly one tree; they assert
 * {@code trees.size() == 1}.
 */
public class SmoothSkygridLikelihood extends AbstractCoalescentLikelihood implements Citable, Reportable {

    /** The trees; only the first one is used by the likelihood and gradient code. */
    private final List<TreeModel> trees;
    /**
     * Log population sizes. Across grid point {@code k} the inverse population size moves
     * from {@code exp(-logPop[k])} to {@code exp(-logPop[k + 1])}.
     */
    private final Parameter logPopSizeParameter;
    /** Locations of the skygrid change points. */
    private final Parameter gridPointParameter;
    /** Rate (steepness) handed to every smooth-function call; only element 0 is read. */
    private final Parameter smoothRate;
    // NOTE(review): populationSizeInverse and lineageCount are constructed but never read in
    // this class.
    private final SmoothSkygridPopulationSizeInverse populationSizeInverse;
    private final OldSmoothLineageCount lineageCount;
    private final GlobalSigmoidSmoothFunction smoothFunction;
    /** One interval view per tree, in the same order as {@link #trees}. */
    private final List<BigFastTreeIntervals> intervalsList;

    // Cached sums filled by calculateTmpSums(). Writing
    //   s(x) = smoothFunction.getInverseOneMinusExponential(x, rate),
    //   L(x) = smoothFunction.getLogOnePlusExponential(x, rate),
    //   t_i = tmpTimes[i], e_i = tmpLineageEffect[i], g_k = grid point k,
    //   theta_k = exp(-logPop[k]), M = getMaxGridIndex(grid, rootHeight):
    //   tmpA[i] = sum_{j != i} e_j * s(t_j - t_i)
    //   tmpB[i] = sum_{k < M} (theta_{k+1} - theta_k) * s(g_k - t_i)
    //   tmpC[i] = L(t_i - rootHeight) - L(t_i - 0)
    //   tmpD[k] = L(g_k - rootHeight) - L(g_k - 0)
    //   tmpE[k] = sum_i e_i * s(t_i - g_k)
    //   tmpF[k] = tmpE[k]^2 - sum_i (e_i * s(t_i - g_k))^2
    // The *DerivOverS arrays are filled only by calculateTmpSumDerivatives().
    private final double[] tmpA;
    private final double[] tmpADerivOverS;
    private final double[] tmpB;
    private final double[] tmpBDerivOverS;
    private final double[] tmpC;
    private final double[] tmpCDerivOverS;
    private final double[] tmpD;
    private final double[] tmpE;
    private final double[] tmpF;
    /** Per distinct node height: +1 for each tip and -1 for each internal node at that height. */
    private final double[] tmpLineageEffect;
    /** Distinct node heights in ascending order; the first {@link #uniqueTimes} entries are valid. */
    private final double[] tmpTimes;
    /** Number of nodes sharing each distinct height. */
    private final int[] tmpCounts;
    /** Number of valid entries in the per-height arrays above. */
    private int uniqueTimes;
    /** Whether the cached sums reflect the current tree and parameters. */
    private boolean tmpSumsKnown;

    /**
     * Creates the likelihood and registers the trees, their interval views and the parameters
     * as dependencies.
     *
     * @param name                model name passed to the superclass
     * @param trees               trees; scratch arrays are sized from the first one
     * @param logPopSizeParameter log population sizes
     * @param gridPointParameter  grid point locations
     * @param smoothRate          smoothing rate (element 0 is used)
     */
    public SmoothSkygridLikelihood(String name, List<TreeModel> trees, Parameter logPopSizeParameter,
                                   Parameter gridPointParameter, Parameter smoothRate) {
        super(name);
        this.trees = trees;
        this.logPopSizeParameter = logPopSizeParameter;
        this.gridPointParameter = gridPointParameter;
        this.smoothRate = smoothRate;
        this.smoothFunction = new GlobalSigmoidSmoothFunction();
        this.populationSizeInverse = new SmoothSkygridPopulationSizeInverse(
                logPopSizeParameter, gridPointParameter, this.smoothFunction, smoothRate);
        this.lineageCount = new OldSmoothLineageCount(trees.get(0), this.smoothFunction, smoothRate);
        this.intervalsList = new ArrayList<>();

        int nodeCount = trees.get(0).getNodeCount();
        int gridCount = gridPointParameter.getDimension();
        this.tmpA = new double[nodeCount];
        this.tmpB = new double[nodeCount];
        this.tmpC = new double[nodeCount];
        this.tmpADerivOverS = new double[nodeCount];
        this.tmpBDerivOverS = new double[nodeCount];
        this.tmpCDerivOverS = new double[nodeCount];
        this.tmpD = new double[gridCount];
        this.tmpE = new double[gridCount];
        this.tmpF = new double[gridCount];
        this.tmpLineageEffect = new double[nodeCount];
        this.tmpTimes = new double[nodeCount];
        this.tmpCounts = new int[nodeCount];
        this.tmpSumsKnown = false;

        for (TreeModel tree : trees) {
            BigFastTreeIntervals intervals = new BigFastTreeIntervals(tree);
            this.intervalsList.add(intervals);
            this.addModel(intervals);
        }
        for (TreeModel tree : trees) {
            this.addModel(tree);
        }
        this.addVariable(logPopSizeParameter);
        this.addVariable(gridPointParameter);
        this.addVariable(smoothRate);
    }

    @Override
    public String getReport() {
        return "smoothSkygrid(" + this.getLogLikelihood() + ")";
    }

    /** Returns the {@code n}-th tree. */
    public Tree getTree(int n) {
        return this.trees.get(n);
    }

    /** Units are not tracked by this likelihood; always returns {@code null}. */
    @Override
    public Units.Type getUnits() {
        return null;
    }

    /** Ignored: units are not tracked by this likelihood. */
    @Override
    public void setUnits(Units.Type type) {
    }

    /** Returns {@code exp(-logPopSize[index])}, the inverse of population size {@code index}. */
    private double inversePopSize(int index) {
        return Math.exp(-this.logPopSizeParameter.getParameterValue(index));
    }

    /**
     * Gradient of the log-likelihood with respect to the internal node heights. It is built from
     * the interval-based terms (lineage-count differences, single and pair-product integrations).
     * It appears to differentiate the formulation in {@link #calculateLogLikelihood2()}
     * (TODO confirm).
     *
     * <p>NOTE(review): the intervals are read without checking {@code intervalsKnown}; confirm
     * they are up to date when this is called.
     *
     * @return array whose entry {@code i} belongs to {@code tree.getNode(externalNodeCount + i)}
     */
    public double[] getGradientWrtNodeHeight() {
        assert (this.trees.size() == 1);
        Tree tree = this.trees.get(0);
        BigFastTreeIntervals intervals = this.intervalsList.get(0);
        int externalCount = tree.getExternalNodeCount();
        int internalCount = tree.getInternalNodeCount();
        double rootHeight = tree.getNodeHeight(tree.getRoot());
        double rate = this.smoothRate.getParameterValue(0);
        int rootSlot = tree.getRoot().getNumber() - externalCount;
        double[] gradient = new double[internalCount];

        // d/dh log(1/N(h)) at every internal-node height: f'(h) / f(h).
        for (int i = 0; i < internalCount; i++) {
            double height = tree.getNodeHeight(tree.getNode(externalCount + i));
            gradient[i] += this.getLogSmoothPopulationSizeInverseDerivative(height, rootHeight)
                    / this.getSmoothPopulationSizeInverse(height, rootHeight);
        }

        // Single-integration term for every non-root internal node.
        double firstInversePopSize = this.inversePopSize(0);
        for (int i = 0; i < internalCount; i++) {
            NodeRef node = tree.getNode(externalCount + i);
            if (tree.isRoot(node)) {
                continue;
            }
            double lineageDiff = this.getLineageCountDifference(
                    intervals.getIntervalIndexForNode(node.getNumber()), intervals);
            gradient[i] += -firstInversePopSize * lineageDiff
                    * this.smoothFunction.getSingleIntegrationDerivative(0.0, rootHeight, tree.getNodeHeight(node), rate);
        }

        // Root: the root height is also the upper limit of the single integration.
        // NOTE(review): this walks node numbers 0 .. intervalCount-1, not intervals; confirm intent.
        double rootSingleTerm = 0.0;
        for (int n = 0; n < intervals.getIntervalCount(); n++) {
            NodeRef node = tree.getNode(n);
            double lineageDiff = this.getLineageCountDifference(
                    intervals.getIntervalIndexForNode(node.getNumber()), intervals);
            rootSingleTerm += -lineageDiff * this.smoothFunction.getSingleIntegrationDerivativeWrtEndTime(
                    rootHeight, tree.getNodeHeight(node), rate);
        }
        gradient[rootSlot] += firstInversePopSize * rootSingleTerm;

        // Pair-product integration term (grid transitions times lineage transitions).
        int maxGridIndex = getMaxGridIndex(this.gridPointParameter, rootHeight);
        for (int i = 0; i < internalCount; i++) {
            NodeRef node = tree.getNode(i + externalCount);
            double lineageDiff = this.getLineageCountDifference(
                    intervals.getIntervalIndexForNode(node.getNumber()), intervals);
            double pairSum = 0.0;
            for (int k = 0; k < maxGridIndex + 1; k++) {
                pairSum += (this.inversePopSize(k + 1) - this.inversePopSize(k))
                        * this.smoothFunction.getPairProductIntegrationDerivative(0.0, rootHeight,
                        tree.getNodeHeight(node), this.gridPointParameter.getParameterValue(k), rate);
            }
            gradient[i] -= lineageDiff * pairSum;
        }

        // Root contribution through the upper limit of the pair-product integration.
        double rootPairTerm = 0.0;
        for (int j = 0; j < intervals.getIntervalCount() + 1; j++) {
            double lineageDiff = this.getLineageCountDifference(j, intervals);
            double intervalTime = intervals.getIntervalTime(j);
            for (int k = 0; k < maxGridIndex + 1; k++) {
                rootPairTerm -= lineageDiff * (this.inversePopSize(k + 1) - this.inversePopSize(k))
                        * this.smoothFunction.getPairProductIntegrationDerivativeWrtEndTime(0.0, rootHeight,
                        intervalTime, this.gridPointParameter.getParameterValue(k), rate);
            }
        }
        gradient[rootSlot] += rootPairTerm;
        return gradient;
    }

    /**
     * Change in the number of lineage pairs {@code k(k-1)/2} at the start of interval
     * {@code intervalIndex}. The lineage count before interval 0 and after the last interval is
     * treated as zero.
     *
     * @param intervalIndex interval index in {@code [0, intervalCount]}
     * @param intervals     the tree intervals
     * @return pairs(count in interval n) - pairs(count in interval n-1)
     */
    private double getLineageCountDifference(int intervalIndex, BigFastTreeIntervals intervals) {
        if (intervalIndex == 0) {
            return lineagePairsTimesTwo(intervals.getLineageCount(0)) / 2.0;
        }
        if (intervalIndex == intervals.getIntervalCount()) {
            return -lineagePairsTimesTwo(intervals.getLineageCount(intervalIndex - 1)) / 2.0;
        }
        // Both products are formed in double precision; the previous code multiplied the second
        // one as int, which overflows for very large lineage counts.
        return (lineagePairsTimesTwo(intervals.getLineageCount(intervalIndex))
                - lineagePairsTimesTwo(intervals.getLineageCount(intervalIndex - 1))) / 2.0;
    }

    /** Returns {@code k * (k - 1)} in double precision so large lineage counts cannot overflow. */
    private static double lineagePairsTimesTwo(double lineageCount) {
        return lineageCount * (lineageCount - 1.0);
    }

    /**
     * Fills the cached sums described in the field comments for the current tree and
     * parameters. Does nothing if they are already known.
     */
    private void calculateTmpSums() {
        if (this.tmpSumsKnown) {
            return;
        }
        TreeModel treeModel = this.trees.get(0);
        double rootHeight = treeModel.getNodeHeight(treeModel.getRoot());
        int maxGridIndex = getMaxGridIndex(this.gridPointParameter, rootHeight);
        double rate = this.smoothRate.getParameterValue(0);

        this.groupNodesByHeight(treeModel);

        for (int i = 0; i < this.uniqueTimes; i++) {
            double time = this.tmpTimes[i];
            double sum = 0.0;
            for (int j = 0; j < this.uniqueTimes; j++) {
                if (j != i) {
                    sum += this.tmpLineageEffect[j]
                            * this.smoothFunction.getInverseOneMinusExponential(this.tmpTimes[j] - time, rate);
                }
            }
            this.tmpA[i] = sum;
        }

        for (int i = 0; i < this.uniqueTimes; i++) {
            double time = this.tmpTimes[i];
            double sum = 0.0;
            for (int k = 0; k < maxGridIndex; k++) {
                sum += (this.inversePopSize(k + 1) - this.inversePopSize(k))
                        * this.smoothFunction.getInverseOneMinusExponential(
                        this.gridPointParameter.getParameterValue(k) - time, rate);
            }
            this.tmpB[i] = sum;
        }

        for (int i = 0; i < this.uniqueTimes; i++) {
            double time = this.tmpTimes[i];
            this.tmpC[i] = this.smoothFunction.getLogOnePlusExponential(time - rootHeight, rate)
                    - this.smoothFunction.getLogOnePlusExponential(time - 0.0, rate);
        }

        for (int k = 0; k < maxGridIndex; k++) {
            double gridPoint = this.gridPointParameter.getParameterValue(k);
            this.tmpD[k] = this.smoothFunction.getLogOnePlusExponential(gridPoint - rootHeight, rate)
                    - this.smoothFunction.getLogOnePlusExponential(gridPoint - 0.0, rate);
            double sum = 0.0;
            double sumOfSquares = 0.0;
            for (int i = 0; i < this.uniqueTimes; i++) {
                double term = this.smoothFunction.getInverseOneMinusExponential(this.tmpTimes[i] - gridPoint, rate)
                        * this.tmpLineageEffect[i];
                sum += term;
                sumOfSquares += term * term;
            }
            this.tmpE[k] = sum;
            this.tmpF[k] = sum * sum - sumOfSquares;
        }
        this.tmpSumsKnown = true;
    }

    /**
     * Sorts the nodes by height and merges nodes of equal height. This fills {@link #tmpTimes},
     * {@link #tmpCounts}, {@link #tmpLineageEffect} and {@link #uniqueTimes}.
     */
    private void groupNodesByHeight(TreeModel treeModel) {
        int nodeCount = treeModel.getNodeCount();
        NodeRef[] sortedNodes = new NodeRef[nodeCount];
        System.arraycopy(treeModel.getNodes(), 0, sortedNodes, 0, nodeCount);
        Arrays.parallelSort(sortedNodes,
                (first, second) -> Double.compare(treeModel.getNodeHeight(first), treeModel.getNodeHeight(second)));

        double groupTime = treeModel.getNodeHeight(sortedNodes[0]);
        // Effect of the node that actually sorted first (previously node number 0 was used).
        double groupEffect = this.getLineageCountEffect(treeModel, sortedNodes[0].getNumber());
        int groupSize = 1;
        int group = 0;
        this.tmpTimes[group] = groupTime;
        for (int i = 1; i < sortedNodes.length; i++) {
            NodeRef node = sortedNodes[i];
            double height = treeModel.getNodeHeight(node);
            if (height == groupTime) {
                groupSize++;
                groupEffect += this.getLineageCountEffect(treeModel, node.getNumber());
                continue;
            }
            // Close the current group and open a new one at this height.
            this.tmpLineageEffect[group] = groupEffect;
            this.tmpCounts[group] = groupSize;
            group++;
            this.tmpTimes[group] = height;
            groupSize = 1;
            groupEffect = this.getLineageCountEffect(treeModel, node.getNumber());
            groupTime = height;
        }
        this.tmpLineageEffect[group] = groupEffect;
        this.tmpCounts[group] = groupSize;
        this.uniqueTimes = group + 1;
    }

    /**
     * Fills {@link #tmpADerivOverS}, {@link #tmpBDerivOverS} and {@link #tmpCDerivOverS}.
     * The cached sums are recomputed first if they are stale.
     */
    private void calculateTmpSumDerivatives() {
        if (!this.tmpSumsKnown) {
            this.calculateTmpSums();
        }
        TreeModel treeModel = this.trees.get(0);
        double rootHeight = treeModel.getNodeHeight(treeModel.getRoot());
        int maxGridIndex = getMaxGridIndex(this.gridPointParameter, rootHeight);
        double rate = this.smoothRate.getParameterValue(0);

        for (int i = 0; i < this.uniqueTimes; i++) {
            double time = this.tmpTimes[i];
            double sum = 0.0;
            for (int j = 0; j < this.uniqueTimes; j++) {
                if (j == i) {
                    continue;
                }
                double s = this.smoothFunction.getInverseOneMinusExponential(this.tmpTimes[j] - time, rate);
                sum += this.tmpLineageEffect[j] * s * (1.0 - s);
            }
            this.tmpADerivOverS[i] = -sum;
        }

        for (int i = 0; i < this.uniqueTimes; i++) {
            double time = this.tmpTimes[i];
            double sum = 0.0;
            for (int k = 0; k < maxGridIndex; k++) {
                double s = this.smoothFunction.getInverseOneMinusExponential(
                        this.gridPointParameter.getParameterValue(k) - time, rate);
                sum += (this.inversePopSize(k + 1) - this.inversePopSize(k)) * s * (1.0 - s);
            }
            this.tmpBDerivOverS[i] = -sum;
        }

        for (int i = 0; i < this.uniqueTimes; i++) {
            this.tmpCDerivOverS[i] = this.smoothFunction.getSingleIntegrationDerivative(
                    0.0, rootHeight, this.tmpTimes[i], rate);
        }
    }

    /**
     * Log-likelihood = sum of log smoothed inverse population sizes at the internal-node
     * heights, plus the single, double and triple integration terms over {@code [0, rootHeight]}.
     * The value is cached through {@code likelihoodKnown}.
     */
    @Override
    protected double calculateLogLikelihood() {
        assert (this.trees.size() == 1);
        if (!this.likelihoodKnown) {
            TreeModel treeModel = this.trees.get(0);
            double rootHeight = treeModel.getNodeHeight(treeModel.getRoot());
            int maxGridIndex = getMaxGridIndex(this.gridPointParameter, rootHeight);
            this.calculateTmpSums();
            double sumSquaredEffects = this.sumOfSquaredLineageEffects();
            double tripleIntegration = this.getTripleIntegration(0.0, rootHeight, maxGridIndex, sumSquaredEffects);
            double doubleIntegration = this.getDoubleIntegration(0.0, rootHeight, maxGridIndex, sumSquaredEffects);
            double singleIntegration = this.getSingleIntegration(0.0, rootHeight);
            double logInverseSizes = 0.0;
            int externalCount = treeModel.getExternalNodeCount();
            for (int i = 0; i < treeModel.getInternalNodeCount(); i++) {
                NodeRef node = treeModel.getNode(externalCount + i);
                logInverseSizes += Math.log(this.getSmoothPopulationSizeInverse(treeModel.getNodeHeight(node), rootHeight));
            }
            this.logLikelihood = logInverseSizes + singleIntegration + doubleIntegration + tripleIntegration;
            this.likelihoodKnown = true;
        }
        return this.logLikelihood;
    }

    /** Returns the sum over distinct heights of {@code tmpLineageEffect[i]^2}. */
    private double sumOfSquaredLineageEffects() {
        double sum = 0.0;
        for (int i = 0; i < this.uniqueTimes; i++) {
            sum += this.tmpLineageEffect[i] * this.tmpLineageEffect[i];
        }
        return sum;
    }

    /** Restores the superclass state and invalidates the cached sums. */
    @Override
    protected void restoreState() {
        super.restoreState();
        this.tmpSumsKnown = false;
    }

    /**
     * Work-in-progress gradient built from the cached-sum formulation of
     * {@link #calculateLogLikelihood()}. It is not called anywhere in this class.
     *
     * <p>NOTE(review): the helpers write one entry per distinct node height ({@code uniqueTimes},
     * which includes tip heights). The returned array has only {@code internalNodeCount}
     * entries, so this can index past its end. The mapping from distinct heights to internal
     * nodes still needs to be defined.
     */
    private double[] getGradientWrtNodeHeightNew() {
        if (!this.likelihoodKnown) {
            this.calculateLogLikelihood();
        }
        // The helpers below read the tmp*DerivOverS arrays, which are only filled here.
        this.calculateTmpSumDerivatives();
        TreeModel treeModel = this.trees.get(0);
        double rootHeight = treeModel.getNodeHeight(treeModel.getRoot());
        int maxGridIndex = getMaxGridIndex(this.gridPointParameter, rootHeight);
        double[] gradient = new double[treeModel.getInternalNodeCount()];
        this.getGradientWrtNodeHeightFromSingleIntegration(0.0, rootHeight, gradient);
        this.getGradientWrtNodeHeightFromDoubleIntegration(0.0, rootHeight, maxGridIndex, gradient);
        this.getGradientWrtNodeHeightFromTripleIntegration(0.0, rootHeight, maxGridIndex, gradient);
        return gradient;
    }

    /**
     * Triple-integration term of the log-likelihood, assembled from the cached sums.
     * {@link #calculateTmpSums()} must have run.
     *
     * @param startTime         lower integration limit (callers pass 0)
     * @param endTime           upper integration limit (callers pass the root height)
     * @param maxGridIndex      result of {@link #getMaxGridIndex} for the root height
     * @param sumSquaredEffects sum of squared per-height lineage effects
     */
    double getTripleIntegration(double startTime, double endTime, int maxGridIndex, double sumSquaredEffects) {
        double rate = this.smoothRate.getParameterValue(0);
        double result = 0.0;
        for (int i = 0; i < this.uniqueTimes; i++) {
            result += this.tmpLineageEffect[i] * this.tmpA[i] * this.tmpB[i] * this.tmpC[i];
        }
        result *= 2.0;
        for (int k = 0; k < maxGridIndex; k++) {
            result += (this.inversePopSize(k + 1) - this.inversePopSize(k)) * this.tmpF[k] * this.tmpD[k];
        }
        result /= -rate * 2.0;
        double inverseSizeSpan = this.inversePopSize(maxGridIndex) - this.inversePopSize(0);
        double window = endTime - startTime;
        result += -0.5 * (1.0 - sumSquaredEffects) * inverseSizeSpan * window;

        double baseline = inverseSizeSpan * window;
        double correction = 0.0;
        for (int i = 0; i < this.uniqueTimes; i++) {
            double time = this.tmpTimes[i];
            double squaredEffect = this.tmpLineageEffect[i] * this.tmpLineageEffect[i];
            double sigmoidSpan = this.smoothFunction.getInverseOnePlusExponential(time - startTime, rate)
                    - this.smoothFunction.getInverseOnePlusExponential(time - endTime, rate);
            double term = baseline;
            for (int k = 0; k < maxGridIndex; k++) {
                double s = this.smoothFunction.getInverseOneMinusExponential(
                        this.gridPointParameter.getParameterValue(k) - time, rate);
                term += (this.inversePopSize(k + 1) - this.inversePopSize(k)) / rate
                        * (s * sigmoidSpan + (2.0 - s) * s * this.tmpC[i] + (1.0 - s) * (1.0 - s) * this.tmpD[k]);
            }
            correction += term * squaredEffect;
        }
        return result + correction * -0.5;
    }

    /**
     * Adds the derivative of {@link #getTripleIntegration} to {@code gradient}, one entry per
     * distinct node height. Requires both the cached sums and their derivatives.
     */
    private void getGradientWrtNodeHeightFromTripleIntegration(double startTime, double endTime, int maxGridIndex,
                                                               double[] gradient) {
        double rate = this.smoothRate.getParameterValue(0);
        for (int i = 0; i < this.uniqueTimes; i++) {
            double effect = this.tmpLineageEffect[i];
            double time = this.tmpTimes[i];
            // Product rule on effect * A * B * C.
            gradient[i] += effect * (this.tmpADerivOverS[i] * this.tmpB[i] * this.tmpC[i]
                    + this.tmpA[i] * this.tmpBDerivOverS[i] * this.tmpC[i]
                    + this.tmpA[i] * this.tmpB[i] * this.tmpCDerivOverS[i]);
            for (int k = 0; k < maxGridIndex; k++) {
                double s = this.smoothFunction.getInverseOneMinusExponential(
                        time - this.gridPointParameter.getParameterValue(k), rate);
                gradient[i] += (this.inversePopSize(k + 1) - this.inversePopSize(k)) * this.tmpD[k]
                        * (this.tmpE[k] - effect * s) * s * (1.0 - s) * effect;
            }
            double sigmoidAtStart = this.smoothFunction.getInverseOnePlusExponential(time - startTime, rate);
            // Evaluated at the END of the window, matching getTripleIntegration. This previously
            // reused the start time, which made the span and its derivative identically zero.
            double sigmoidAtEnd = this.smoothFunction.getInverseOnePlusExponential(time - endTime, rate);
            double sigmoidSpan = sigmoidAtStart - sigmoidAtEnd;
            double sigmoidSpanDerivative = -sigmoidAtStart * (1.0 - sigmoidAtStart) + sigmoidAtEnd * (1.0 - sigmoidAtEnd);
            for (int k = 0; k < maxGridIndex; k++) {
                double s = this.smoothFunction.getInverseOneMinusExponential(
                        this.gridPointParameter.getParameterValue(k) - time, rate);
                double sDerivative = -s * (1.0 - s);
                gradient[i] += (this.inversePopSize(k + 1) - this.inversePopSize(k))
                        * (sDerivative * sigmoidSpan + s * sigmoidSpanDerivative
                        + 2.0 * (1.0 - s) * sDerivative * this.tmpC[i]
                        + (2.0 - s) * s * this.tmpCDerivOverS[i]
                        + 2.0 * (1.0 - s) * -sDerivative * this.tmpD[k]);
            }
        }
    }

    /**
     * Double-integration term of the log-likelihood, assembled from the cached sums.
     * {@link #calculateTmpSums()} must have run. Parameters are as for
     * {@link #getTripleIntegration}.
     */
    double getDoubleIntegration(double startTime, double endTime, int maxGridIndex, double sumSquaredEffects) {
        double rate = this.smoothRate.getParameterValue(0);
        double crossSum = 0.0;
        double quadraticSum = 0.0;
        for (int i = 0; i < this.uniqueTimes; i++) {
            double effect = this.tmpLineageEffect[i];
            crossSum += effect * this.tmpA[i] * this.tmpC[i];
            quadraticSum += effect * effect
                    * this.smoothFunction.getQuadraticIntegration(startTime, endTime, this.tmpTimes[i], rate);
        }
        crossSum /= rate;
        crossSum += 0.5 * (1.0 - sumSquaredEffects) * (endTime - startTime);
        double lineagePart = -(quadraticSum * 0.5 + crossSum) * this.inversePopSize(0);

        double gridPart = 0.0;
        for (int i = 0; i < this.uniqueTimes; i++) {
            gridPart += 0.5 * this.tmpB[i] * this.tmpC[i] * this.tmpLineageEffect[i];
        }
        for (int k = 0; k < maxGridIndex; k++) {
            gridPart += 0.5 * this.tmpE[k] * this.tmpD[k] * (this.inversePopSize(k + 1) - this.inversePopSize(k));
        }
        gridPart /= rate;
        gridPart += 0.5 * (endTime - startTime) * (this.inversePopSize(maxGridIndex) - this.inversePopSize(0));
        return lineagePart + gridPart;
    }

    /**
     * Adds the derivative of {@link #getDoubleIntegration} to {@code gradient}, one entry per
     * distinct node height. Requires both the cached sums and their derivatives.
     */
    private void getGradientWrtNodeHeightFromDoubleIntegration(double startTime, double endTime, int maxGridIndex,
                                                               double[] gradient) {
        double rate = this.smoothRate.getParameterValue(0);
        double firstInversePopSize = this.inversePopSize(0);
        for (int i = 0; i < this.uniqueTimes; i++) {
            double effect = this.tmpLineageEffect[i];
            double time = this.tmpTimes[i];
            gradient[i] += -effect * (this.tmpA[i] * this.tmpCDerivOverS[i] + this.tmpADerivOverS[i] * this.tmpC[i])
                    * firstInversePopSize;
            // Both boundary derivatives are scaled by 1/rate. Previously operator precedence
            // applied the division to the start-time term only, mixing two scales.
            double boundaryTerm = (this.smoothFunction.getDerivative(time, endTime, 0.0, 1.0, rate)
                    - this.smoothFunction.getDerivative(time, startTime, 0.0, 1.0, rate)) / rate;
            gradient[i] += effect * effect
                    * (this.smoothFunction.getSingleIntegrationDerivative(startTime, endTime, time, rate) + boundaryTerm)
                    * -0.5 * firstInversePopSize;
            gradient[i] += 0.5 * effect * (this.tmpB[i] * this.tmpCDerivOverS[i] + this.tmpBDerivOverS[i] * this.tmpC[i]);
            for (int k = 0; k < maxGridIndex; k++) {
                double s = this.smoothFunction.getInverseOneMinusExponential(
                        time - this.gridPointParameter.getParameterValue(k), rate);
                gradient[i] += 0.5 * this.tmpD[k] * (this.inversePopSize(k + 1) - this.inversePopSize(k))
                        * s * (1.0 - s) * effect;
            }
        }
    }

    /**
     * Single-integration term: {@code 0.5 * exp(-logPop[0]) * sum_i e_i * I(start, end, t_i)},
     * where {@code I} is {@code smoothFunction.getSingleIntegration}.
     */
    private double getSingleIntegration(double startTime, double endTime) {
        double rate = this.smoothRate.getParameterValue(0);
        double sum = 0.0;
        for (int i = 0; i < this.uniqueTimes; i++) {
            sum += this.tmpLineageEffect[i]
                    * this.smoothFunction.getSingleIntegration(startTime, endTime, this.tmpTimes[i], rate);
        }
        return sum * (0.5 * this.inversePopSize(0));
    }

    /** Adds the derivative of {@link #getSingleIntegration} to {@code gradient}, one entry per distinct height. */
    private void getGradientWrtNodeHeightFromSingleIntegration(double startTime, double endTime, double[] gradient) {
        double rate = this.smoothRate.getParameterValue(0);
        double firstInversePopSize = this.inversePopSize(0);
        for (int i = 0; i < this.uniqueTimes; i++) {
            gradient[i] += this.tmpLineageEffect[i]
                    * this.smoothFunction.getSingleIntegrationDerivative(startTime, endTime, this.tmpTimes[i], rate)
                    * 0.5 * firstInversePopSize;
        }
    }

    /** Forwards to the superclass and invalidates the cached sums. */
    @Override
    protected void handleModelChangedEvent(Model model, Object object, int n) {
        super.handleModelChangedEvent(model, object, n);
        this.tmpSumsKnown = false;
    }

    /** Any parameter change invalidates both the cached sums and the cached likelihood. */
    @Override
    protected void handleVariableChangedEvent(Variable variable, int n, Variable.ChangeType changeType) {
        this.tmpSumsKnown = false;
        this.likelihoodKnown = false;
    }

    /** Returns +1 for a tip and -1 for an internal node (the change in lineage count at that node). */
    private double getLineageCountEffect(Tree tree, int nodeNumber) {
        return tree.isExternal(tree.getNode(nodeNumber)) ? 1.0 : -1.0;
    }

    /**
     * Alternative log-likelihood built from the tree intervals instead of the cached sums:
     * the sum of log smoothed inverse population sizes at the internal-node heights, minus the
     * single-integration term, minus the pair-product integration term. Recalculates the
     * intervals first if they are not known.
     */
    protected double calculateLogLikelihood2() {
        assert (this.trees.size() == 1);
        if (!this.intervalsKnown) {
            for (IntervalList intervalList : this.intervalsList) {
                intervalList.calculateIntervals();
            }
            this.intervalsKnown = true;
        }
        Tree tree = this.trees.get(0);
        BigFastTreeIntervals intervals = this.intervalsList.get(0);
        double rootHeight = tree.getNodeHeight(tree.getRoot());
        double rate = this.smoothRate.getParameterValue(0);

        double logInverseSizes = 0.0;
        for (int i = 0; i < tree.getInternalNodeCount(); i++) {
            NodeRef node = tree.getNode(tree.getExternalNodeCount() + i);
            logInverseSizes += Math.log(this.getSmoothPopulationSizeInverse(tree.getNodeHeight(node), rootHeight));
        }

        double singleTerm = this.getLineageCountDifference(0, intervals)
                * this.smoothFunction.getSingleIntegration(0.0, rootHeight, intervals.getIntervalTime(0), rate);
        for (int j = 1; j < intervals.getIntervalCount() + 1; j++) {
            singleTerm += this.getLineageCountDifference(j, intervals)
                    * this.smoothFunction.getSingleIntegration(0.0, rootHeight, intervals.getIntervalTime(j), rate);
        }
        singleTerm *= this.inversePopSize(0);

        int maxGridIndex = getMaxGridIndex(this.gridPointParameter, rootHeight);
        double pairTerm = 0.0;
        double firstLineageDiff = this.getLineageCountDifference(0, intervals);
        double firstIntervalTime = intervals.getIntervalTime(0);
        for (int k = 0; k < maxGridIndex + 1; k++) {
            pairTerm += firstLineageDiff * (this.inversePopSize(k + 1) - this.inversePopSize(k))
                    * this.smoothFunction.getPairProductIntegration(0.0, rootHeight, firstIntervalTime,
                    this.gridPointParameter.getParameterValue(k), rate);
        }
        for (int j = 1; j < intervals.getIntervalCount() + 1; j++) {
            double lineageDiff = this.getLineageCountDifference(j, intervals);
            double intervalTime = intervals.getIntervalTime(j);
            for (int k = 0; k < maxGridIndex + 1; k++) {
                pairTerm += lineageDiff * (this.inversePopSize(k + 1) - this.inversePopSize(k))
                        * this.smoothFunction.getPairProductIntegration(0.0, rootHeight, intervalTime,
                        this.gridPointParameter.getParameterValue(k), rate);
            }
        }
        return logInverseSizes - singleTerm - pairTerm;
    }

    /**
     * Scans the grid points from the last one downward and stops at the first value that does
     * not exceed {@code time}, or at index 0. Returns that index plus one.
     *
     * @param parameter grid point parameter
     * @param time      time to locate (callers pass the root height)
     * @return one more than the index where the scan stopped
     */
    public static int getMaxGridIndex(Parameter parameter, double time) {
        int index = parameter.getDimension() - 1;
        while (parameter.getParameterValue(index) > time && index > 0) {
            index--;
        }
        return index + 1;
    }

    /**
     * Smoothed inverse population size at {@code time}:
     * {@code exp(-logPop[0]) + sum_k (exp(-logPop[k+1]) - exp(-logPop[k])) * smooth(time, g_k)}.
     *
     * <p>NOTE(review): the loop runs {@code k <= getMaxGridIndex(rootHeight)}. When the root is at
     * or beyond the last grid point, this reads grid point {@code dimension}. Confirm the
     * parameter dimensions.
     */
    private double getSmoothPopulationSizeInverse(double time, double rootHeight) {
        int maxGridIndex = getMaxGridIndex(this.gridPointParameter, rootHeight);
        double rate = this.smoothRate.getParameterValue(0);
        double value = this.inversePopSize(0);
        for (int k = 0; k < maxGridIndex + 1; k++) {
            value += (this.inversePopSize(k + 1) - this.inversePopSize(k))
                    * this.smoothFunction.getSmoothValue(time, this.gridPointParameter.getParameterValue(k), 0.0, 1.0, rate);
        }
        return value;
    }

    /**
     * Derivative with respect to {@code time} of {@link #getSmoothPopulationSizeInverse}.
     * Despite the name, this is not the derivative of the log. Callers divide by
     * {@link #getSmoothPopulationSizeInverse} to obtain it.
     */
    private double getLogSmoothPopulationSizeInverseDerivative(double time, double rootHeight) {
        int maxGridIndex = getMaxGridIndex(this.gridPointParameter, rootHeight);
        double rate = this.smoothRate.getParameterValue(0);
        double derivative = 0.0;
        for (int k = 0; k < maxGridIndex + 1; k++) {
            derivative += (this.inversePopSize(k + 1) - this.inversePopSize(k))
                    * this.smoothFunction.getDerivative(time, this.gridPointParameter.getParameterValue(k), 0.0, 1.0, rate);
        }
        return derivative;
    }

    /** Total number of internal nodes over all trees. */
    @Override
    public int getNumberOfCoalescentEvents() {
        int count = 0;
        for (Tree tree : this.trees) {
            count += tree.getInternalNodeCount();
        }
        return count;
    }

    @Override
    public double getCoalescentEventsStatisticValue(int n) {
        throw new RuntimeException("Not yet implemented.");
    }

    @Override
    public Citation.Category getCategory() {
        return Citation.Category.TREE_PRIORS;
    }

    @Override
    public String getDescription() {
        return "Differentiable skygrid coalescent";
    }

    @Override
    public List<Citation> getCitations() {
        return Arrays.asList(CommonCitations.GILL_2013_IMPROVING, new Citation(new Author[]{new Author("Y", "Bao"), new Author("MA", "Suchard"), new Author("X", "Ji")}, Citation.Status.IN_PREPARATION));
    }

    /**
     * Smoothed inverse population size as a sum of smooth steps, one per grid point. The first
     * step goes from {@code exp(-logPop[0])} to {@code exp(-logPop[1])}. Each later step has
     * height {@code exp(-logPop[i+1]) - exp(-logPop[i])}.
     */
    class SmoothSkygridPopulationSizeInverse {
        private final Parameter logPopSizeParameter;
        private final Parameter gridPointParameter;
        private final GlobalSigmoidSmoothFunction smoothFunction;
        private final Parameter smoothRate;

        SmoothSkygridPopulationSizeInverse(Parameter logPopSizeParameter, Parameter gridPointParameter,
                                           GlobalSigmoidSmoothFunction smoothFunction, Parameter smoothRate) {
            this.logPopSizeParameter = logPopSizeParameter;
            this.gridPointParameter = gridPointParameter;
            this.smoothRate = smoothRate;
            this.smoothFunction = smoothFunction;
        }

        /** Returns the smoothed inverse population size at {@code time}. */
        double getPopulationSizeInverse(double time) {
            double total = 0.0;
            for (int i = 0; i < this.gridPointParameter.getDimension(); ++i) {
                double startValue = i == 0 ? Math.exp(-this.logPopSizeParameter.getParameterValue(0)) : 0.0;
                double endValue = i == 0
                        ? Math.exp(-this.logPopSizeParameter.getParameterValue(1))
                        : Math.exp(-this.logPopSizeParameter.getParameterValue(i + 1))
                        - Math.exp(-this.logPopSizeParameter.getParameterValue(i));
                total += this.smoothFunction.getSmoothValue(time, this.gridPointParameter.getParameterValue(i),
                        startValue, endValue, this.smoothRate.getParameterValue(0));
            }
            return total;
        }
    }

    /**
     * Smoothed lineage count: a +1 smooth step at every tip height and a -1 smooth step at
     * every internal-node height.
     */
    class OldSmoothLineageCount {
        private final Tree tree;
        private final GlobalSigmoidSmoothFunction smoothFunction;
        private final Parameter smoothRate;

        OldSmoothLineageCount(Tree tree, GlobalSigmoidSmoothFunction smoothFunction, Parameter smoothRate) {
            this.tree = tree;
            this.smoothFunction = smoothFunction;
            this.smoothRate = smoothRate;
        }

        /** Returns the smoothed number of lineages at {@code time}. */
        double getLineageCount(double time) {
            double count = 0.0;
            for (int n = 0; n < this.tree.getExternalNodeCount(); ++n) {
                count += this.smoothFunction.getSmoothValue(time, this.tree.getNodeHeight(this.tree.getNode(n)), 0.0, 1.0, this.smoothRate.getParameterValue(0));
            }
            for (int n = this.tree.getExternalNodeCount(); n < this.tree.getNodeCount(); ++n) {
                count += this.smoothFunction.getSmoothValue(time, this.tree.getNodeHeight(this.tree.getNode(n)), 0.0, -1.0, this.smoothRate.getParameterValue(0));
            }
            return count;
        }
    }
}

