/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.topics;

import cc.mallet.topics.ParallelTopicModel;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import gnu.trove.TIntIntHashMap;
import java.io.File;

/**
 * Wrapper around {@link ParallelTopicModel} exposing its inputs and outputs as plain
 * doubles, booleans, strings and arrays.
 *
 * <p>Integer-valued settings (topic count, iteration counts, optimisation intervals) are
 * accepted as {@code double} and truncated with {@link Math#floor(double)} before being
 * handed to the model. Presumably this is because the intended caller (e.g. R via rJava)
 * passes every number as a double; that is not verified here.
 */
public class RTopicModel {
    /** Documents most recently passed to a {@code loadDocuments} method, or {@code null}. */
    public InstanceList instances = null;
    /** The wrapped topic model, created by the constructor. */
    public ParallelTopicModel model = null;

    /**
     * Creates a new, untrained topic model.
     *
     * @param numTopics number of topics (floored to an int)
     * @param alpha     passed unchanged to the {@code ParallelTopicModel} constructor as alpha
     * @param beta      passed unchanged to the {@code ParallelTopicModel} constructor as beta
     */
    public RTopicModel(double numTopics, double alpha, double beta) {
        this.model = new ParallelTopicModel((int) Math.floor(numTopics), alpha, beta);
    }

    /**
     * Loads a serialized instance list from disk and adds its documents to the model.
     *
     * @param filename path of a file readable by {@code InstanceList.load}
     */
    public void loadDocuments(String filename) {
        loadDocuments(InstanceList.load(new File(filename)));
    }

    /**
     * Remembers {@code instances} (used later by {@link #getWordFrequencies()}) and adds
     * them to the model.
     *
     * @param instances documents to train on
     */
    public void loadDocuments(InstanceList instances) {
        this.instances = instances;
        this.model.addInstances(instances);
    }

    /**
     * Pipes one raw text document into {@code instances}.
     *
     * @param instances target list; its pipe processes the text
     * @param id        instance name
     * @param text      raw document text (the instance's data)
     */
    public static void addInstance(InstanceList instances, String id, String text) {
        instances.addThruPipe(new Instance(text, null, id, null));
    }

    /**
     * Pipes several raw text documents into {@code instances}, pairing {@code ids[i]} with
     * {@code texts[i]}.
     *
     * @param instances target list; its pipe processes each text
     * @param ids       instance names
     * @param texts     raw document texts, same length as {@code ids}
     * @throws IllegalArgumentException if {@code ids} and {@code texts} differ in length
     */
    public static void addInstances(InstanceList instances, String[] ids, String[] texts) {
        if (ids.length != texts.length) {
            throw new IllegalArgumentException("ids has " + ids.length
                    + " entries but texts has " + texts.length);
        }
        for (int i = 0; i < ids.length; i++) {
            addInstance(instances, ids[i], texts[i]);
        }
    }

    /**
     * Configures hyperparameter optimisation.
     *
     * @param frequency passed (floored) to {@code setOptimizeInterval}
     * @param burnin    passed (floored) to {@code setBurninPeriod}
     */
    public void setAlphaOptimization(double frequency, double burnin) {
        this.model.setBurninPeriod((int) Math.floor(burnin));
        this.model.setOptimizeInterval((int) Math.floor(frequency));
    }

    /**
     * Sets the iteration count and runs {@code estimate()} on the model.
     *
     * <p>Any exception is reported on standard error rather than propagated (matching
     * {@link #writeState(String)}). The model is left in whatever state estimation reached.
     *
     * @param numIterations number of iterations (floored to an int)
     */
    public void train(double numIterations) {
        try {
            this.model.setNumIterations((int) Math.floor(numIterations));
            this.model.estimate();
        } catch (Exception e) {
            // Previously swallowed silently, which hid failed or partial training from callers.
            System.err.println(e);
        }
    }

    /**
     * Delegates to the model's {@code maximize}.
     *
     * @param numIterations number of iterations (floored to an int)
     */
    public void maximize(double numIterations) {
        this.model.maximize((int) Math.floor(numIterations));
    }

    /**
     * Returns the model's alpha array itself (not a copy).
     *
     * @return the model's {@code alpha} array
     */
    public double[] getAlpha() {
        return this.model.alpha;
    }

    /**
     * Returns the vocabulary indexed by word type.
     *
     * <p>The array is sized by the alphabet, but only the first {@code numTypes} slots are
     * filled. Any extra slots stay {@code null}.
     *
     * @return word strings indexed by type id
     */
    public String[] getVocabulary() {
        String[] vocab = new String[this.model.alphabet.size()];
        for (int type = 0; type < this.model.numTypes; type++) {
            vocab[type] = (String) this.model.alphabet.lookupObject(type);
        }
        return vocab;
    }

    /**
     * Returns the name of each document in model order.
     *
     * <p>Names are converted with {@code toString()}, so instance lists whose names are not
     * Strings are handled. A {@code null} name yields {@code null}.
     *
     * @return one name per document in {@code model.data}
     */
    public String[] getDocumentNames() {
        String[] docNames = new String[this.model.data.size()];
        for (int doc = 0; doc < docNames.length; doc++) {
            Object name = this.model.data.get(doc).instance.getName();
            // A plain cast fails for non-String names, so convert instead.
            docNames[doc] = name == null ? null : name.toString();
        }
        return docNames;
    }

    /**
     * Topic-by-word weights computed only over the documents selected by {@code documentMask}.
     *
     * @param documentMask {@code documentMask[d]} selects document {@code d}. It must have at
     *                     least as many entries as the model has documents.
     * @param normalized   if true, divide each topic row by its token total (plus
     *                     {@code numTypes * beta} when smoothed)
     * @param smoothed     if true, add {@code beta} to every cell before normalising
     * @return a {@code [numTopics][numTypes]} matrix
     * @throws IllegalArgumentException if {@code documentMask} is shorter than the corpus
     */
    public double[][] getSubCorpusTopicWords(boolean[] documentMask, boolean normalized, boolean smoothed) {
        int numDocs = this.model.data.size();
        if (documentMask.length < numDocs) {
            throw new IllegalArgumentException("documentMask has " + documentMask.length
                    + " entries but the model holds " + numDocs + " documents");
        }
        double[][] result = new double[this.model.numTopics][this.model.numTypes];
        int[] subCorpusTokensPerTopic = new int[this.model.numTopics];
        for (int doc = 0; doc < numDocs; doc++) {
            if (!documentMask[doc]) {
                continue;
            }
            int[] words = ((FeatureSequence) this.model.data.get(doc).instance.getData()).getFeatures();
            int[] topics = this.model.data.get(doc).topicSequence.getFeatures();
            for (int position = 0; position < topics.length; position++) {
                result[topics[position]][words[position]] += 1.0;
                subCorpusTokensPerTopic[topics[position]]++;
            }
        }
        smoothAndNormalize(result, subCorpusTokensPerTopic, normalized, smoothed);
        return result;
    }

    /**
     * Topic-by-word weights over the whole training corpus.
     *
     * @param normalized if true, divide each topic row by the model's {@code tokensPerTopic}
     *                   (plus {@code numTypes * beta} when smoothed)
     * @param smoothed   if true, add {@code beta} to every cell before normalising
     * @return a {@code [numTopics][numTypes]} matrix
     */
    public double[][] getTopicWords(boolean normalized, boolean smoothed) {
        double[][] result = new double[this.model.numTopics][this.model.numTypes];
        for (int type = 0; type < this.model.numTypes; type++) {
            int[] topicCounts = this.model.typeTopicCounts[type];
            // Each entry packs a topic id (low bits, topicMask) with its count (high bits,
            // shifted by topicBits). The scan stops at the first non-positive entry;
            // presumably unused slots are zero and sit at the end (confirm in ParallelTopicModel).
            for (int index = 0; index < topicCounts.length && topicCounts[index] > 0; index++) {
                int topic = topicCounts[index] & this.model.topicMask;
                int count = topicCounts[index] >> this.model.topicBits;
                result[topic][type] += count;
            }
        }
        smoothAndNormalize(result, this.model.tokensPerTopic, normalized, smoothed);
        return result;
    }

    /**
     * Optionally smooths, then optionally normalises, a topic-by-type weight matrix in place.
     *
     * <p>Smoothing adds {@code beta} to every cell. Normalising multiplies row {@code t} by
     * {@code 1 / (tokensPerTopic[t] + numTypes * beta)} when smoothed, or by
     * {@code 1 / tokensPerTopic[t]} otherwise. An unsmoothed row with zero tokens therefore
     * becomes NaN.
     */
    private void smoothAndNormalize(double[][] weights, int[] tokensPerTopic,
                                    boolean normalized, boolean smoothed) {
        int numTopics = this.model.numTopics;
        int numTypes = this.model.numTypes;
        if (smoothed) {
            for (int topic = 0; topic < numTopics; topic++) {
                for (int type = 0; type < numTypes; type++) {
                    weights[topic][type] += this.model.beta;
                }
            }
        }
        if (!normalized) {
            return;
        }
        for (int topic = 0; topic < numTopics; topic++) {
            double total = smoothed
                    ? (double) tokensPerTopic[topic] + (double) numTypes * this.model.beta
                    : (double) tokensPerTopic[topic];
            double normalizer = 1.0 / total;
            for (int type = 0; type < numTypes; type++) {
                weights[topic][type] *= normalizer;
            }
        }
    }

    /**
     * Document-by-topic weights from the current topic assignments.
     *
     * @param normalized if true, scale each document row to sum to 1
     * @param smoothed   if true, add {@code alpha[topic]} to each cell before normalising
     * @return a {@code [numDocuments][numTopics]} matrix
     */
    public double[][] getDocumentTopics(boolean normalized, boolean smoothed) {
        double[][] result = new double[this.model.data.size()][this.model.numTopics];
        for (int doc = 0; doc < this.model.data.size(); doc++) {
            double[] row = result[doc];
            for (int topic : this.model.data.get(doc).topicSequence.getFeatures()) {
                row[topic] += 1.0;
            }
            if (smoothed) {
                for (int topic = 0; topic < this.model.numTopics; topic++) {
                    row[topic] += this.model.alpha[topic];
                }
            }
            if (!normalized) {
                continue;
            }
            double sum = 0.0;
            for (int topic = 0; topic < this.model.numTopics; topic++) {
                sum += row[topic];
            }
            double normalizer = 1.0 / sum;
            for (int topic = 0; topic < this.model.numTopics; topic++) {
                row[topic] *= normalizer;
            }
        }
        return result;
    }

    /**
     * Counts word usage over the loaded instances.
     *
     * @return a {@code [numTypes][2]} matrix. Column 0 is the total number of occurrences of
     *         the type, and column 1 is the number of documents containing it.
     * @throws IllegalStateException if no documents have been loaded
     */
    public double[][] getWordFrequencies() {
        if (this.instances == null) {
            throw new IllegalStateException("You must load instances before you can count features");
        }
        double[][] result = new double[this.model.numTypes][2];
        for (Instance instance : this.instances) {
            FeatureSequence features = (FeatureSequence) instance.getData();
            // Per-document occurrence count for each feature present in this document.
            TIntIntHashMap docCounts = new TIntIntHashMap();
            for (int i = 0; i < features.getLength(); i++) {
                docCounts.adjustOrPutValue(features.getIndexAtPosition(i), 1, 1);
            }
            // Visit every key. The old bound of keys.length - 1 skipped one feature per document.
            for (int feature : docCounts.keys()) {
                result[feature][0] += docCounts.get(feature);
                result[feature][1] += 1.0;
            }
        }
        return result;
    }

    /**
     * Writes the sampler state via {@code printState}. Any exception is printed to standard
     * error rather than thrown.
     *
     * @param filename destination file path
     */
    public void writeState(String filename) {
        try {
            this.model.printState(new File(filename));
        } catch (Exception e) {
            System.err.println(e);
        }
    }
}

