/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.operators;

import dr.evolution.tree.MutableTreeModel;
import dr.evolution.tree.NodeRef;
import dr.evolution.util.Taxon;
import dr.evomodel.continuous.AbstractMultivariateTraitLikelihood;
import dr.evomodel.continuous.SampledMultivariateTraitLikelihood;
import dr.geo.GeoSpatialCollectionModel;
import dr.geo.GeoSpatialDistribution;
import dr.inference.distribution.MultivariateDistributionLikelihood;
import dr.inference.model.MatrixParameter;
import dr.inference.operators.GibbsOperator;
import dr.inference.operators.MCMCOperator;
import dr.inference.operators.SimpleMCMCOperator;
import dr.math.MathUtils;
import dr.math.distributions.MultivariateDistribution;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.matrixAlgebra.SymmetricMatrix;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Logger;

/**
 * Gibbs operator that resamples the multivariate trait of one randomly chosen tree node from
 * its full conditional under the diffusion model of a {@link SampledMultivariateTraitLikelihood}.
 *
 * <p>For a non-root node the conditional is multivariate normal. Its mean is the average of
 * the parent and child traits, weighted by the reciprocal rescaled branch lengths. Its
 * precision is the diffusion precision scaled by the sum of those weights.
 *
 * <p>The root is only updated once a root prior has been supplied via {@link #setRootPrior}.
 * Its conditional combines the child traits with that normal prior.
 *
 * <p>Geospatial taxon priors and an optional {@link GeoSpatialCollectionModel} act as
 * truncations: draws they give zero density are rejected and redrawn.
 */
public class TraitGibbsOperator
extends SimpleMCMCOperator
implements GibbsOperator {
    public static final String GIBBS_OPERATOR = "traitGibbsOperator";
    public static final String INTERNAL_ONLY = "onlyInternalNodes";
    public static final String TIP_WITH_PRIORS_ONLY = "onlyTipsWithPriors";
    public static final String NODE_PRIOR = "nodePrior";
    public static final String NODE_LABEL = "taxon";
    public static final String ROOT_PRIOR = "rootPrior";

    /** Retry bound for truncated sampling: at most {@code MAX_TRIES + 1} draws per update. */
    private static final int MAX_TRIES = 10000;

    private final MutableTreeModel treeModel;
    private final MatrixParameter precisionMatrixParameter;
    private final SampledMultivariateTraitLikelihood traitModel;
    /** Trait dimension, read from the root's trait vector at construction time. */
    private final int dim;
    private final String traitName;
    /** Truncating geospatial priors keyed by taxon; {@code null} until the first one is added. */
    private Map<Taxon, GeoSpatialDistribution> nodeGeoSpatialPrior;
    /** Normal taxon priors; stored by the setter, but the update rejects nodes that have one. */
    private Map<Taxon, MultivariateNormalDistribution> nodeMVNPrior;
    /** Optional truncating prior on a whole parameter; {@code null} when absent. */
    private GeoSpatialCollectionModel parameterPrior = null;
    private boolean onlyInternalNodes = true;
    private boolean onlyTipsWithPriors = true;
    /** Whether the root may be chosen; set to true by {@link #setRootPrior}. */
    private boolean sampleRoot = false;
    private double[] rootPriorMean;
    private double[][] rootPriorPrecision;

    public static XMLObjectParser PARSER = new AbstractXMLObjectParser(){
        private final String[] names = new String[]{TraitGibbsOperator.GIBBS_OPERATOR, "internalTraitGibbsOperator"};
        private final XMLSyntaxRule[] rules = new XMLSyntaxRule[]{
                AttributeRule.newDoubleRule("weight"),
                AttributeRule.newBooleanRule(TraitGibbsOperator.INTERNAL_ONLY, true),
                AttributeRule.newBooleanRule(TraitGibbsOperator.TIP_WITH_PRIORS_ONLY, true),
                new ElementRule(SampledMultivariateTraitLikelihood.class),
                new ElementRule(MultivariateDistributionLikelihood.class, 0, Integer.MAX_VALUE),
                new ElementRule(TraitGibbsOperator.ROOT_PRIOR, new XMLSyntaxRule[]{new ElementRule(MultivariateDistributionLikelihood.class)}, true),
                new ElementRule(GeoSpatialCollectionModel.class, true)
        };

        @Override
        public String getParserName() {
            return TraitGibbsOperator.GIBBS_OPERATOR;
        }

        @Override
        public String[] getParserNames() {
            return this.names;
        }

        /**
         * Builds the operator from XML. The sampled trait likelihood is required. Optional
         * children are a normal root prior, geospatial taxon priors and a parameter-wide
         * geospatial collection prior.
         *
         * @throws XMLParseException if the root prior is not multivariate normal, or a
         *         geospatial prior names a taxon that is not in the tree
         */
        @Override
        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            double weight = xo.getDoubleAttribute("weight");
            boolean internalOnly = xo.getAttribute(TraitGibbsOperator.INTERNAL_ONLY, true);
            boolean tipsWithPriorsOnly = xo.getAttribute(TraitGibbsOperator.TIP_WITH_PRIORS_ONLY, true);
            // Look up the concrete sampled type declared in the syntax rules, so the cast
            // below cannot fail on some other trait likelihood subclass.
            SampledMultivariateTraitLikelihood traitLikelihood =
                    (SampledMultivariateTraitLikelihood) xo.getChild(SampledMultivariateTraitLikelihood.class);
            TraitGibbsOperator operator = new TraitGibbsOperator(traitLikelihood, internalOnly, tipsWithPriorsOnly);
            operator.setWeight(weight);

            XMLObject rootPriorXO = xo.getChild(TraitGibbsOperator.ROOT_PRIOR);
            if (rootPriorXO != null) {
                MultivariateDistributionLikelihood rootPrior =
                        (MultivariateDistributionLikelihood) rootPriorXO.getChild(MultivariateDistributionLikelihood.class);
                MultivariateDistribution rootDistribution = rootPrior.getDistribution();
                // The root update is only conjugate for a normal prior. Checking the concrete
                // type reports a bad prior as a parse error rather than a ClassCastException.
                if (!(rootDistribution instanceof MultivariateNormalDistribution)) {
                    throw new XMLParseException("Only multivariate normal priors allowed for Gibbs sampling the root trait");
                }
                operator.setRootPrior((MultivariateNormalDistribution) rootDistribution);
            }

            for (int i = 0; i < xo.getChildCount(); ++i) {
                Object child = xo.getChild(i);
                if (!(child instanceof MultivariateDistributionLikelihood)) {
                    continue;
                }
                MultivariateDistribution distribution = ((MultivariateDistributionLikelihood) child).getDistribution();
                if (!(distribution instanceof GeoSpatialDistribution)) {
                    continue;
                }
                GeoSpatialDistribution geoPrior = (GeoSpatialDistribution) distribution;
                Taxon taxon = this.getTaxon(traitLikelihood.getTreeModel(), geoPrior.getLabel());
                operator.setTaxonPrior(taxon, geoPrior);
                System.err.println("Adding truncated prior for taxon '" + taxon + "'");
            }

            GeoSpatialCollectionModel collectionPrior = (GeoSpatialCollectionModel) xo.getChild(GeoSpatialCollectionModel.class);
            if (collectionPrior != null) {
                operator.setParameterPrior(collectionPrior);
                System.err.println("Adding truncated prior '" + collectionPrior.getId() + "' for parameter '" + collectionPrior.getParameter().getId() + "'");
            }
            return operator;
        }

        /** Resolves a taxon label against the tree, failing the parse if it is absent. */
        private Taxon getTaxon(MutableTreeModel tree, String label) throws XMLParseException {
            int index = tree.getTaxonIndex(label);
            if (index == -1) {
                throw new XMLParseException("Taxon '" + label + "' not found for geoSpatialDistribution element in traitGibbsOperator element");
            }
            return tree.getTaxon(index);
        }

        @Override
        public String getParserDescription() {
            return "This element returns a multivariate Gibbs operator on traits for possible all nodes.";
        }

        @Override
        public Class getReturnType() {
            return MCMCOperator.class;
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }
    };

    /**
     * @param traitLikelihood     sampled trait likelihood that supplies the tree, trait name,
     *                            diffusion precision and branch rescaling
     * @param onlyInternalNodes   if true, only internal nodes are candidates for an update
     * @param onlyTipsWithPriors  when tips are candidates, skip tips without a geospatial prior
     */
    public TraitGibbsOperator(SampledMultivariateTraitLikelihood traitLikelihood, boolean onlyInternalNodes, boolean onlyTipsWithPriors) {
        this.traitModel = traitLikelihood;
        this.treeModel = traitLikelihood.getTreeModel();
        this.precisionMatrixParameter = (MatrixParameter) traitLikelihood.getDiffusionModel().getPrecisionParameter();
        this.traitName = traitLikelihood.getTraitName();
        this.onlyInternalNodes = onlyInternalNodes;
        this.onlyTipsWithPriors = onlyTipsWithPriors;
        this.dim = this.treeModel.getMultivariateNodeTrait(this.treeModel.getRoot(), this.traitName).length;
        Logger.getLogger("dr.evomodel").info("Using *NEW* trait Gibbs operator");
    }

    /**
     * Sets a normal prior on the root trait and enables root updates. The distribution's
     * scale matrix is used as the prior *precision* (NOTE(review): confirm that
     * getScaleMatrix() returns a precision, not a covariance).
     */
    public void setRootPrior(MultivariateNormalDistribution rootPrior) {
        this.rootPriorMean = rootPrior.getMean();
        this.rootPriorPrecision = rootPrior.getScaleMatrix();
        this.sampleRoot = true;
    }

    /**
     * Registers a prior for a taxon's trait.
     *
     * @throws RuntimeException if the distribution is neither geospatial nor multivariate normal
     */
    public void setTaxonPrior(Taxon taxon, MultivariateDistribution prior) {
        if (prior instanceof GeoSpatialDistribution) {
            if (this.nodeGeoSpatialPrior == null) {
                this.nodeGeoSpatialPrior = new HashMap<Taxon, GeoSpatialDistribution>();
            }
            this.nodeGeoSpatialPrior.put(taxon, (GeoSpatialDistribution) prior);
        } else if (prior instanceof MultivariateNormalDistribution) {
            if (this.nodeMVNPrior == null) {
                this.nodeMVNPrior = new HashMap<Taxon, MultivariateNormalDistribution>();
            }
            this.nodeMVNPrior.put(taxon, (MultivariateNormalDistribution) prior);
        } else {
            throw new RuntimeException("Only flat/truncated geospatial and multivariate normal distributions allowed");
        }
    }

    /** Sets a prior whose zero-likelihood states cause draws to be rejected. */
    public void setParameterPrior(GeoSpatialCollectionModel collectionPrior) {
        this.parameterPrior = collectionPrior;
    }

    public int getStepCount() {
        return 1;
    }

    private boolean nodeGeoSpatialPriorExists(NodeRef node) {
        return this.nodeGeoSpatialPrior != null && this.nodeGeoSpatialPrior.containsKey(this.treeModel.getNodeTaxon(node));
    }

    private boolean nodeMVNPriorExists(NodeRef node) {
        return this.nodeMVNPrior != null && this.nodeMVNPrior.containsKey(this.treeModel.getNodeTaxon(node));
    }

    /**
     * Draws a new trait for one eligible node from its full conditional. Draws that a
     * truncating prior assigns zero density are rejected and redrawn.
     *
     * @return 0.0 (Gibbs moves contribute no Hastings ratio)
     * @throws RuntimeException if no acceptable draw is found within the retry bound; the
     *         node's original trait is restored before throwing
     */
    @Override
    public double doOperation() {
        NodeRef root = this.treeModel.getRoot();
        NodeRef node = this.pickNode(root);
        double[] originalTrait = this.treeModel.getMultivariateNodeTrait(node, this.traitName);
        MeanPrecision conditional = node != root ? this.operateNotRoot(node) : this.operateRoot(node);
        GeoSpatialDistribution nodePrior = this.nodeGeoSpatialPriorExists(node)
                ? this.nodeGeoSpatialPrior.get(this.treeModel.getNodeTaxon(node))
                : null;
        boolean hasParameterPrior = this.parameterPrior != null;

        for (int tries = 0; tries <= MAX_TRIES; ++tries) {
            double[] draw = MultivariateNormalDistribution.nextMultivariateNormalPrecision(conditional.mean, conditional.precision);
            // Cheap per-node truncation first; it does not require touching the tree.
            if (nodePrior != null && nodePrior.logPdf(draw) == Double.NEGATIVE_INFINITY) {
                continue;
            }
            this.treeModel.setMultivariateTrait(node, this.traitName, draw);
            // The parameter-wide prior is evaluated against the tree state, so test after writing.
            if (!hasParameterPrior || this.parameterPrior.getLogLikelihood() != Double.NEGATIVE_INFINITY) {
                return 0.0;
            }
        }
        this.treeModel.setMultivariateTrait(node, this.traitName, originalTrait);
        throw new RuntimeException("Truncated Gibbs is stuck!");
    }

    /**
     * Uniformly samples nodes until one is eligible. The root is eligible only when a root
     * prior is set. When tips are candidates and {@code onlyTipsWithPriors} is set, tips
     * without a geospatial prior are skipped.
     *
     * <p>NOTE(review): this loops forever if no node is eligible. For example, when the root
     * is the only internal node and no root prior is set.
     */
    private NodeRef pickNode(NodeRef root) {
        while (true) {
            NodeRef candidate;
            if (this.onlyInternalNodes) {
                candidate = this.treeModel.getInternalNode(MathUtils.nextInt(this.treeModel.getInternalNodeCount()));
            } else {
                candidate = this.treeModel.getNode(MathUtils.nextInt(this.treeModel.getNodeCount()));
                boolean isTip = this.treeModel.getChildCount(candidate) == 0;
                if (this.onlyTipsWithPriors && isTip && !this.nodeGeoSpatialPriorExists(candidate)) {
                    continue;
                }
            }
            if (candidate == root && !this.sampleRoot) {
                continue;
            }
            return candidate;
        }
    }

    /**
     * Full conditional of a non-root node. The mean is the branch-weight average of the parent
     * and child traits; the precision is the diffusion precision times the summed weights.
     * Each weight is {@code 1 / getRescaledBranchLengthForPrecision(branch)}.
     *
     * @throws RuntimeException if the node carries a multivariate normal taxon prior (unsupported)
     */
    private MeanPrecision operateNotRoot(NodeRef node) {
        if (this.nodeMVNPriorExists(node)) {
            throw new RuntimeException("Still trying to implement multivariate normal taxon priors");
        }
        double[][] precision = this.precisionMatrixParameter.getParameterAsMatrix();
        double[] mean = new double[this.dim];

        NodeRef parent = this.treeModel.getParent(node);
        double parentWeight = 1.0 / this.traitModel.getRescaledBranchLengthForPrecision(node);
        addScaled(mean, this.treeModel.getMultivariateNodeTrait(parent, this.traitName), parentWeight);
        double totalWeight = parentWeight;

        for (int c = 0; c < this.treeModel.getChildCount(node); ++c) {
            NodeRef child = this.treeModel.getChild(node, c);
            double[] childTrait = this.treeModel.getMultivariateNodeTrait(child, this.traitName);
            double childWeight = 1.0 / this.traitModel.getRescaledBranchLengthForPrecision(child);
            addScaled(mean, childTrait, childWeight);
            totalWeight += childWeight;
        }

        for (int i = 0; i < this.dim; ++i) {
            mean[i] /= totalWeight;
            // Scale the upper triangle and mirror it, keeping the precision exactly symmetric.
            for (int j = i; j < this.dim; ++j) {
                double scaled = precision[i][j] * totalWeight;
                precision[i][j] = scaled;
                precision[j][i] = scaled;
            }
        }
        return new MeanPrecision(mean, precision);
    }

    /**
     * Full conditional of the root given its children and the normal root prior.
     * The precision is {@code P * sum(w_c) + P0}. The mean is that precision's inverse applied
     * to {@code sum(w_c * P * x_c) + P0 * mu0}.
     */
    private MeanPrecision operateRoot(NodeRef root) {
        double[][] precision = this.precisionMatrixParameter.getParameterAsMatrix();
        // Precision-weighted sum of the children and the prior mean (information vector).
        double[] information = new double[this.dim];
        double totalWeight = 0.0;

        for (int c = 0; c < this.treeModel.getChildCount(root); ++c) {
            NodeRef child = this.treeModel.getChild(root, c);
            double[] childTrait = this.treeModel.getMultivariateNodeTrait(child, this.traitName);
            double childWeight = 1.0 / this.traitModel.getRescaledBranchLengthForPrecision(child);
            for (int i = 0; i < this.dim; ++i) {
                for (int j = 0; j < this.dim; ++j) {
                    information[i] += precision[i][j] * childWeight * childTrait[j];
                }
            }
            totalWeight += childWeight;
        }

        for (int i = 0; i < this.dim; ++i) {
            for (int j = 0; j < this.dim; ++j) {
                information[i] += this.rootPriorPrecision[i][j] * this.rootPriorMean[j];
                precision[i][j] = precision[i][j] * totalWeight + this.rootPriorPrecision[i][j];
            }
        }

        double[][] covariance = new SymmetricMatrix(precision).inverse().toComponents();
        double[] mean = new double[this.dim];
        for (int i = 0; i < this.dim; ++i) {
            for (int j = 0; j < this.dim; ++j) {
                mean[i] += covariance[i][j] * information[j];
            }
        }
        return new MeanPrecision(mean, precision);
    }

    /** In-place {@code target[i] += source[i] * weight} over the trait dimension. */
    private static void addScaled(double[] target, double[] source, double weight) {
        for (int i = 0; i < target.length; ++i) {
            target[i] += source[i] * weight;
        }
    }

    public String getPerformanceSuggestion() {
        return null;
    }

    @Override
    public String getOperatorName() {
        return GIBBS_OPERATOR;
    }

    /** Mean vector and precision matrix of a multivariate normal full conditional. */
    static class MeanPrecision {
        final double[] mean;
        final double[][] precision;

        MeanPrecision(double[] mean, double[][] precision) {
            this.mean = mean;
            this.precision = precision;
        }
    }
}

