/*
 * Decompiled with CFR 0.152.
 */
package bartMachine;

import OpenSourceExtensions.StatUtil;
import bartMachine.StatToolbox;
import bartMachine.TreeArrayIllustration;
import bartMachine.bartMachineRegression;
import bartMachine.bartMachineTreeNode;
import java.io.Serializable;

/**
 * Classification variant of {@link bartMachineRegression} for a response coded as 0 or 1.
 *
 * <p>Each Gibbs iteration first redraws a latent value for every observation
 * (stored in {@code y_trans}) from the previous iteration's sum-of-trees fit and
 * the observed label in {@code y_orig}, then updates every tree and its leaf
 * values through the inherited {@code SampleTree} / {@code SampleMusWrapper} steps.
 * The variance samples are pinned to {@link #SIGSQ_FOR_PROBIT} rather than sampled.
 *
 * <p>NOTE(review): this looks like the usual probit latent-variable (data
 * augmentation) scheme; confirm against the project documentation.
 */
public class bartMachineClassification
extends bartMachineRegression
implements Serializable {
    /** Fixed value written into every entry of {@code gibbs_samples_of_sigsq}. */
    private static final double SIGSQ_FOR_PROBIT = 1.0;

    /**
     * Runs one Gibbs iteration: redraws the latent responses, then samples each
     * tree and its leaf values in index order. Unlike the fields it overrides,
     * nothing here samples the variance.
     */
    @Override
    protected void DoOneGibbsSample() {
        // Scratch array handed to SampleTree for the trees of the current iteration.
        bartMachineTreeNode[] treesThisSample = new bartMachineTreeNode[this.num_trees];
        TreeArrayIllustration illustration = new TreeArrayIllustration(this.gibbs_sample_num, this.unique_name);

        // The latent responses must be refreshed before any tree is updated,
        // because the tree updates fit against y_trans.
        this.SampleZs();

        for (int t = 0; t < this.num_trees; ++t) {
            if (this.verbose) {
                this.GibbsSampleDebugMessage(t);
            }
            this.SampleTree(this.gibbs_sample_num, t, treesThisSample, illustration);
            this.SampleMusWrapper(this.gibbs_sample_num, t);
        }
    }

    /**
     * Redraws {@code y_trans[i]} for every observation {@code i} in {@code [0, n)}.
     * The draw is driven by the sum, over all trees of the PREVIOUS Gibbs sample,
     * of each tree's evaluation at row {@code i} of {@code X_y}, together with the
     * observed label {@code y_orig[i]}.
     */
    private void SampleZs() {
        // The previous iteration's trees are the same for every observation,
        // so look them up once instead of once per row.
        bartMachineTreeNode[] previousTrees = this.gibbs_samples_of_bart_trees[this.gibbs_sample_num - 1];
        for (int i = 0; i < this.n; ++i) {
            double[] row = (double[])this.X_y.get(i);
            double sumOfTrees = 0.0;
            for (int t = 0; t < this.num_trees; ++t) {
                sumOfTrees += previousTrees[t].Evaluate(row);
            }
            this.y_trans[i] = this.SampleZi(sumOfTrees, this.y_orig[i]);
        }
    }

    /**
     * Draws one latent value given the current sum-of-trees fit and the observed label.
     *
     * <p>Assuming {@code StatToolbox.normal_cdf} is the standard normal CDF and
     * {@code StatUtil.getInvCDF} its inverse (TODO confirm; the meaning of the
     * {@code false} flag is not visible here), the uniform draw is mapped into the
     * upper part of the CDF range so that the result is non-negative for label 1
     * and non-positive for label 0.
     *
     * <p>NOTE(review): for a fit of very large magnitude the CDF term can round to
     * 1.0, which makes the inverse-CDF argument 1.0; verify how
     * {@code StatUtil.getInvCDF} behaves at that boundary.
     *
     * @param sumOfTrees the summed tree predictions for the observation
     * @param response   the observed label; must be exactly 0.0 or 1.0
     * @return the sampled latent value; for any other label the JVM is terminated
     *         before this method returns
     */
    private double SampleZi(double sumOfTrees, double response) {
        double u = StatToolbox.rand();
        if (response == 1.0) {
            return sumOfTrees + StatUtil.getInvCDF((1.0 - u) * StatToolbox.normal_cdf(-sumOfTrees) + u, false);
        }
        if (response == 0.0) {
            return sumOfTrees - StatUtil.getInvCDF((1.0 - u) * StatToolbox.normal_cdf(sumOfTrees) + u, false);
        }
        System.err.println("SampleZi RESPONSE NOT ZERO / ONE");
        // This is a fatal error, so report failure to the parent process; the
        // previous status of 0 made the abort look like a successful run.
        // NOTE(review): killing the JVM from library code is drastic. Throwing
        // instead would change how callers observe the failure, so confirm
        // the threading model before replacing this with an exception.
        System.exit(1);
        return -1.0; // Only here to satisfy the compiler; System.exit does not return normally.
    }

    /**
     * Performs the inherited setup and then overwrites the first
     * {@code num_gibbs_total_iterations} variance samples with {@link #SIGSQ_FOR_PROBIT}.
     */
    @Override
    protected void SetupGibbsSampling() {
        super.SetupGibbsSampling();
        for (int g = 0; g < this.num_gibbs_total_iterations; ++g) {
            this.gibbs_samples_of_sigsq[g] = SIGSQ_FOR_PROBIT;
        }
    }

    /**
     * Sets the leaf-value prior hyperparameters: mean {@code 0} and variance
     * {@code (3 / (hyper_k * sqrt(num_trees)))^2}.
     */
    @Override
    protected void calculateHyperparameters() {
        this.hyper_mu_mu = 0.0;
        this.hyper_sigsq_mu = Math.pow(3.0 / (this.hyper_k * Math.sqrt(this.num_trees)), 2.0);
    }

    /**
     * Allocates {@code y_trans} with the same length as {@code y_orig}. No values
     * are copied or transformed here; the entries are filled in by {@code SampleZs}.
     */
    @Override
    protected void transformResponseVariable() {
        this.y_trans = new double[this.y_orig.length];
    }

    /**
     * Identity mapping: this class applies no response transformation that would
     * need undoing.
     *
     * @param d the value on the model scale
     * @return {@code d} unchanged
     */
    @Override
    public double un_transform_y(double d) {
        return d;
    }
}

