/*
 * Decompiled with CFR 0.152.
 */
package moa.classifiers.meta;

import com.github.javacliparser.FlagOption;
import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;
import com.github.javacliparser.MultiChoiceOption;
import com.yahoo.labs.samoa.instances.Instance;
import moa.AbstractMOAObject;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.Regressor;
import moa.classifiers.core.driftdetection.ChangeDetector;
import moa.classifiers.trees.ARFFIMTDD;
import moa.core.DoubleVector;
import moa.core.Example;
import moa.core.InstanceExample;
import moa.core.Measurement;
import moa.core.MiscUtils;
import moa.evaluation.BasicRegressionPerformanceEvaluator;
import moa.options.ClassOption;

public class AdaptiveRandomForestRegressor
extends AbstractClassifier
implements Regressor {
    private static final long serialVersionUID = 1L;
    public ClassOption treeLearnerOption = new ClassOption("treeLearner", 'l', "Random Forest Tree.", ARFFIMTDD.class, "ARFFIMTDD");
    public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's', "The number of trees.", 10, 1, Integer.MAX_VALUE);
    public MultiChoiceOption mFeaturesModeOption = new MultiChoiceOption("mFeaturesMode", 'o', "Defines how m, defined by mFeaturesPerTreeSize, is interpreted. M represents the total number of features.", new String[]{"Specified m (integer value)", "sqrt(M)+1", "M-(sqrt(M)+1)", "Percentage (M * (m / 100))"}, new String[]{"SpecifiedM", "SqrtM1", "MSqrtM1", "Percentage"}, 1);
    public IntOption mFeaturesPerTreeSizeOption = new IntOption("mFeaturesPerTreeSize", 'm', "Number of features allowed considered for each split. Negative values corresponds to M - m", 2, Integer.MIN_VALUE, Integer.MAX_VALUE);
    public FloatOption lambdaOption = new FloatOption("lambda", 'a', "The lambda parameter for bagging.", 6.0, 1.0, 3.4028234663852886E38);
    public ClassOption driftDetectionMethodOption = new ClassOption("driftDetectionMethod", 'x', "Change detector for drifts and its parameters", ChangeDetector.class, "ADWINChangeDetector -a 1.0E-5");
    public ClassOption warningDetectionMethodOption = new ClassOption("warningDetectionMethod", 'p', "Change detector for warnings (start training bkg learner)", ChangeDetector.class, "ADWINChangeDetector -a 1.0E-4");
    public FlagOption disableWeightedVote = new FlagOption("disableWeightedVote", 'w', "Should use weighted voting?");
    public FlagOption disableDriftDetectionOption = new FlagOption("disableDriftDetection", 'u', "Should use drift detection? If disabled then bkg learner is also disabled");
    public FlagOption disableBackgroundLearnerOption = new FlagOption("disableBackgroundLearner", 'q', "Should use bkg learner? If disabled then reset tree immediately.");
    protected static final int FEATURES_M = 0;
    protected static final int FEATURES_SQRT = 1;
    protected static final int FEATURES_SQRT_INV = 2;
    protected static final int FEATURES_PERCENT = 3;
    protected ARFFIMTDDBaseLearner[] ensemble;
    protected long instancesSeen;
    protected int subspaceSize;
    protected BasicRegressionPerformanceEvaluator evaluator;

    @Override
    public String getPurposeString() {
        return "Adaptive Random Forest Regressor algorithm for evolving data streams from Gomes et al.";
    }

    @Override
    public void resetLearningImpl() {
        this.ensemble = null;
        this.subspaceSize = 0;
        this.instancesSeen = 0L;
        this.evaluator = new BasicRegressionPerformanceEvaluator();
    }

    @Override
    public void trainOnInstanceImpl(Instance instance) {
        ++this.instancesSeen;
        if (this.ensemble == null) {
            this.initEnsemble(instance);
        }
        for (int i = 0; i < this.ensemble.length; ++i) {
            DoubleVector vote = new DoubleVector(this.ensemble[i].getVotesForInstance(instance));
            InstanceExample example = new InstanceExample(instance);
            this.ensemble[i].evaluator.addResult((Example<Instance>)example, vote.getArrayRef());
            int k = MiscUtils.poisson(this.lambdaOption.getValue(), this.classifierRandom);
            if (k <= 0) continue;
            this.ensemble[i].trainOnInstance(instance, k, this.instancesSeen);
        }
    }

    @Override
    public double[] getVotesForInstance(Instance instance) {
        Instance testInstance = instance.copy();
        if (this.ensemble == null) {
            this.initEnsemble(testInstance);
        }
        double accounted = 0.0;
        DoubleVector predictions = new DoubleVector();
        DoubleVector ages = new DoubleVector();
        DoubleVector performance = new DoubleVector();
        for (int i = 0; i < this.ensemble.length; ++i) {
            double currentPrediction = this.ensemble[i].getVotesForInstance(testInstance)[0];
            ages.addToValue(i, this.instancesSeen - this.ensemble[i].createdOn);
            performance.addToValue(i, this.ensemble[i].evaluator.getSquareError());
            predictions.addToValue(i, currentPrediction);
            accounted += 1.0;
        }
        double predicted = predictions.sumOfValues() / accounted;
        return new double[]{predictions.sumOfValues() / accounted};
    }

    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        return null;
    }

    protected void initEnsemble(Instance instance) {
        int ensembleSize = this.ensembleSizeOption.getValue();
        this.ensemble = new ARFFIMTDDBaseLearner[ensembleSize];
        BasicRegressionPerformanceEvaluator regressionEvaluator = new BasicRegressionPerformanceEvaluator();
        this.subspaceSize = this.mFeaturesPerTreeSizeOption.getValue();
        int n = instance.numAttributes() - 1;
        switch (this.mFeaturesModeOption.getChosenIndex()) {
            case 1: {
                this.subspaceSize = (int)Math.round(Math.sqrt(n)) + 1;
                break;
            }
            case 2: {
                this.subspaceSize = n - (int)Math.round(Math.sqrt(n) + 1.0);
                break;
            }
            case 3: {
                double percent = this.subspaceSize < 0 ? (double)(100 + this.subspaceSize) / 100.0 : (double)this.subspaceSize / 100.0;
                this.subspaceSize = (int)Math.round((double)n * percent);
            }
        }
        if (this.subspaceSize < 0) {
            this.subspaceSize = n + this.subspaceSize;
        }
        if (this.subspaceSize <= 0) {
            this.subspaceSize = 1;
        }
        if (this.subspaceSize > n) {
            this.subspaceSize = n;
        }
        ARFFIMTDD treeLearner = (ARFFIMTDD)this.getPreparedClassOption(this.treeLearnerOption);
        treeLearner.resetLearning();
        for (int i = 0; i < ensembleSize; ++i) {
            treeLearner.subspaceSizeOption.setValue(this.subspaceSize);
            this.ensemble[i] = new ARFFIMTDDBaseLearner(i, (ARFFIMTDD)treeLearner.copy(), (BasicRegressionPerformanceEvaluator)regressionEvaluator.copy(), this.instancesSeen, !this.disableBackgroundLearnerOption.isSet(), !this.disableDriftDetectionOption.isSet(), this.driftDetectionMethodOption, this.warningDetectionMethodOption, false);
        }
    }

    @Override
    public void getModelDescription(StringBuilder out, int indent) {
    }

    @Override
    public boolean isRandomizable() {
        return true;
    }

    protected final class ARFFIMTDDBaseLearner
    extends AbstractMOAObject {
        public int indexOriginal;
        public long createdOn;
        public long lastDriftOn;
        public long lastWarningOn;
        public ARFFIMTDD classifier;
        public boolean isBackgroundLearner;
        protected ClassOption driftOption;
        protected ClassOption warningOption;
        protected ChangeDetector driftDetectionMethod;
        protected ChangeDetector warningDetectionMethod;
        public boolean useBkgLearner;
        public boolean useDriftDetector;
        protected ARFFIMTDDBaseLearner bkgLearner;
        public BasicRegressionPerformanceEvaluator evaluator;
        protected int numberOfDriftsDetected;
        protected int numberOfWarningsDetected;

        private void init(int indexOriginal, ARFFIMTDD instantiatedClassifier, BasicRegressionPerformanceEvaluator evaluatorInstantiated, long instancesSeen, boolean useBkgLearner, boolean useDriftDetector, ClassOption driftOption, ClassOption warningOption, boolean isBackgroundLearner) {
            this.indexOriginal = indexOriginal;
            this.createdOn = instancesSeen;
            this.lastDriftOn = 0L;
            this.lastWarningOn = 0L;
            this.classifier = instantiatedClassifier;
            this.evaluator = evaluatorInstantiated;
            this.useBkgLearner = useBkgLearner;
            this.useDriftDetector = useDriftDetector;
            this.numberOfDriftsDetected = 0;
            this.numberOfWarningsDetected = 0;
            this.isBackgroundLearner = isBackgroundLearner;
            if (this.useDriftDetector) {
                this.driftOption = driftOption;
                this.driftDetectionMethod = ((ChangeDetector)AdaptiveRandomForestRegressor.this.getPreparedClassOption(this.driftOption)).copy();
            }
            if (this.useBkgLearner) {
                this.warningOption = warningOption;
                this.warningDetectionMethod = ((ChangeDetector)AdaptiveRandomForestRegressor.this.getPreparedClassOption(this.warningOption)).copy();
            }
        }

        public ARFFIMTDDBaseLearner(int indexOriginal, ARFFIMTDD instantiatedClassifier, BasicRegressionPerformanceEvaluator evaluatorInstantiated, long instancesSeen, boolean useBkgLearner, boolean useDriftDetector, ClassOption driftOption, ClassOption warningOption, boolean isBackgroundLearner) {
            this.init(indexOriginal, instantiatedClassifier, evaluatorInstantiated, instancesSeen, useBkgLearner, useDriftDetector, driftOption, warningOption, isBackgroundLearner);
        }

        public void reset() {
            if (this.useBkgLearner && this.bkgLearner != null) {
                this.classifier = this.bkgLearner.classifier;
                this.driftDetectionMethod = this.bkgLearner.driftDetectionMethod;
                this.warningDetectionMethod = this.bkgLearner.warningDetectionMethod;
                this.evaluator = this.bkgLearner.evaluator;
                this.createdOn = this.bkgLearner.createdOn;
                this.bkgLearner = null;
            } else {
                this.classifier.resetLearning();
                this.createdOn = AdaptiveRandomForestRegressor.this.instancesSeen;
                this.driftDetectionMethod = ((ChangeDetector)AdaptiveRandomForestRegressor.this.getPreparedClassOption(this.driftOption)).copy();
            }
            this.evaluator.reset();
        }

        public void trainOnInstance(Instance instance, double weight, long instancesSeen) {
            Instance weightedInstance = instance.copy();
            weightedInstance.setWeight(instance.weight() * weight);
            this.classifier.trainOnInstance(weightedInstance);
            if (this.bkgLearner != null) {
                this.bkgLearner.classifier.trainOnInstance(instance);
            }
            if (this.useDriftDetector && !this.isBackgroundLearner) {
                double prediction = this.classifier.getVotesForInstance(instance)[0];
                if (this.useBkgLearner) {
                    this.warningDetectionMethod.input(prediction);
                    if (this.warningDetectionMethod.getChange()) {
                        this.lastWarningOn = instancesSeen;
                        ++this.numberOfWarningsDetected;
                        ARFFIMTDD bkgClassifier = (ARFFIMTDD)this.classifier.copy();
                        bkgClassifier.resetLearning();
                        BasicRegressionPerformanceEvaluator bkgEvaluator = (BasicRegressionPerformanceEvaluator)this.evaluator.copy();
                        bkgEvaluator.reset();
                        this.bkgLearner = new ARFFIMTDDBaseLearner(this.indexOriginal, bkgClassifier, bkgEvaluator, instancesSeen, this.useBkgLearner, this.useDriftDetector, this.driftOption, this.warningOption, true);
                        this.warningDetectionMethod = ((ChangeDetector)AdaptiveRandomForestRegressor.this.getPreparedClassOption(this.warningOption)).copy();
                    }
                }
                this.driftDetectionMethod.input(prediction);
                if (this.driftDetectionMethod.getChange()) {
                    this.lastDriftOn = instancesSeen;
                    ++this.numberOfDriftsDetected;
                    this.reset();
                }
            }
        }

        public double[] getVotesForInstance(Instance instance) {
            return this.classifier.getVotesForInstance(instance);
        }

        @Override
        public void getDescription(StringBuilder sb, int indent) {
        }
    }
}

