/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.classify;

import cc.mallet.classify.Classification;
import cc.mallet.classify.Classifier;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureSelection;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelVector;
import cc.mallet.types.MatrixOps;
import cc.mallet.types.RankedFeatureVector;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintStream;
import java.io.PrintWriter;
import java.io.Serializable;

public class MaxEnt
extends Classifier
implements Serializable {
    /**
     * Flattened weight matrix, one row per label. Each row holds
     * {@code defaultFeatureIndex + 1} weights: one per feature followed by the
     * per-label bias stored at column {@code defaultFeatureIndex}.
     */
    protected double[] parameters;
    /** Column of the per-label bias ("default") weight; equals the number of real features per row. */
    protected int defaultFeatureIndex;
    /** Feature mask applied to every label when {@link #perClassFeatureSelection} is null; may be null. */
    protected FeatureSelection featureSelection;
    /** Optional per-label feature masks; when non-null, entry {@code li} is used for label {@code li}. */
    protected FeatureSelection[] perClassFeatureSelection;
    private static final long serialVersionUID = 1L;
    /** Version tag written first by {@code writeObject} and checked by {@code readObject}. */
    private static final int CURRENT_SERIAL_VERSION = 1;
    /** Serialized marker meaning "optional object/array absent". */
    static final int NULL_INTEGER = -1;
    /** Serialized marker written before an optional object that is present. */
    private static final int PRESENT_MARKER = 1;

    /**
     * Creates a maximum-entropy classifier.
     *
     * @param dataPipe the pipe whose data and target alphabets define the parameter layout
     * @param parameters the flattened weights, or {@code null} to allocate a zero vector of
     *        length {@link #getNumParameters(Pipe)}
     * @param featureSelection a feature mask shared by all labels, or {@code null}
     * @param perClassFeatureSelection per-label feature masks, or {@code null}; must not be
     *        combined with a non-null {@code featureSelection} (checked by assertion)
     */
    public MaxEnt(Pipe dataPipe, double[] parameters, FeatureSelection featureSelection, FeatureSelection[] perClassFeatureSelection) {
        super(dataPipe);
        assert (featureSelection == null || perClassFeatureSelection == null);
        this.parameters = parameters != null ? parameters : new double[MaxEnt.getNumParameters(dataPipe)];
        this.featureSelection = featureSelection;
        this.perClassFeatureSelection = perClassFeatureSelection;
        // The bias column sits right after the last feature known at construction time.
        this.defaultFeatureIndex = dataPipe.getDataAlphabet().size();
    }

    /** Creates a classifier that applies one feature mask to all labels. */
    public MaxEnt(Pipe dataPipe, double[] parameters, FeatureSelection featureSelection) {
        this(dataPipe, parameters, featureSelection, null);
    }

    /** Creates a classifier with a separate feature mask per label. */
    public MaxEnt(Pipe dataPipe, double[] parameters, FeatureSelection[] perClassFeatureSelection) {
        this(dataPipe, parameters, null, perClassFeatureSelection);
    }

    /** Creates a classifier with no feature masks. */
    public MaxEnt(Pipe dataPipe, double[] parameters) {
        this(dataPipe, parameters, null, null);
    }

    /**
     * Returns the live parameter array. It is not a copy, so mutations affect this classifier.
     *
     * @return the flattened weight matrix
     */
    public double[] getParameters() {
        return this.parameters;
    }

    /**
     * Returns the parameter count implied by this classifier's pipe alphabets.
     *
     * @return {@code (numFeatures + 1) * numLabels}
     */
    public int getNumParameters() {
        assert (this.instancePipe.getDataAlphabet() != null);
        assert (this.instancePipe.getTargetAlphabet() != null);
        return MaxEnt.getNumParameters(this.instancePipe);
    }

    /**
     * Returns the parameter count for a pipe: one weight per feature plus one bias, for every label.
     *
     * @param instancePipe the pipe supplying the data and target alphabets
     * @return {@code (dataAlphabetSize + 1) * targetAlphabetSize}
     */
    public static int getNumParameters(Pipe instancePipe) {
        return (instancePipe.getDataAlphabet().size() + 1) * instancePipe.getTargetAlphabet().size();
    }

    /**
     * Replaces the parameter array. The array is stored by reference, not copied.
     *
     * @param parameters the new flattened weight matrix
     */
    public void setParameters(double[] parameters) {
        this.parameters = parameters;
    }

    /**
     * Sets one weight. {@code featureIndex == getDefaultFeatureIndex()} addresses the label's bias.
     * The row stride matches the one used by scoring ({@code defaultFeatureIndex + 1}).
     *
     * @param classIndex label row
     * @param featureIndex feature column
     * @param value the new weight
     */
    public void setParameter(int classIndex, int featureIndex, double value) {
        this.parameters[classIndex * this.featureStride() + featureIndex] = value;
    }

    @Override
    public FeatureSelection getFeatureSelection() {
        return this.featureSelection;
    }

    /**
     * Sets the feature mask shared by all labels.
     *
     * @param fs the mask, or {@code null}
     * @return this classifier, for chaining
     */
    public MaxEnt setFeatureSelection(FeatureSelection fs) {
        this.featureSelection = fs;
        return this;
    }

    @Override
    public FeatureSelection[] getPerClassFeatureSelection() {
        return this.perClassFeatureSelection;
    }

    /**
     * Sets the per-label feature masks.
     *
     * @param fss the masks, or {@code null}
     * @return this classifier, for chaining
     */
    public MaxEnt setPerClassFeatureSelection(FeatureSelection[] fss) {
        this.perClassFeatureSelection = fss;
        return this;
    }

    /** @return the column index of the per-label bias weight */
    public int getDefaultFeatureIndex() {
        return this.defaultFeatureIndex;
    }

    /** @param defaultFeatureIndex the column index of the per-label bias weight (also fixes the row stride) */
    public void setDefaultFeatureIndex(int defaultFeatureIndex) {
        this.defaultFeatureIndex = defaultFeatureIndex;
    }

    /** Width of one label row in {@link #parameters}: every feature weight plus the bias. */
    private int featureStride() {
        return this.defaultFeatureIndex + 1;
    }

    /**
     * Computes the raw linear score for each label: the bias plus the (optionally masked) dot product
     * of the label's weight row with the instance's feature vector.
     *
     * @param instance an instance whose data is a {@link FeatureVector} over this pipe's data alphabet
     * @param scores output array of length {@code numLabels}; overwritten
     */
    public void getUnnormalizedClassificationScores(Instance instance, double[] scores) {
        int stride = this.featureStride();
        int numLabels = this.getLabelAlphabet().size();
        assert (scores.length == numLabels);
        FeatureVector fv = (FeatureVector)instance.getData();
        assert (fv.getAlphabet() == this.instancePipe.getDataAlphabet());
        for (int li = 0; li < numLabels; ++li) {
            FeatureSelection selection = this.perClassFeatureSelection == null ? this.featureSelection : this.perClassFeatureSelection[li];
            scores[li] = this.parameters[li * stride + this.defaultFeatureIndex] + MatrixOps.rowDotProduct(this.parameters, stride, li, fv, this.defaultFeatureIndex, selection);
        }
    }

    /**
     * Computes the posterior distribution over labels with a softmax of the raw scores.
     *
     * @param instance the instance to score
     * @param scores output array of length {@code numLabels}; receives probabilities summing to 1
     */
    public void getClassificationScores(Instance instance, double[] scores) {
        this.getUnnormalizedClassificationScores(instance, scores);
        MaxEnt.softmaxInPlace(scores, this.getLabelAlphabet().size());
    }

    /**
     * Like {@link #getClassificationScores} but divides the raw scores by {@code temperature} before
     * the softmax. Values above 1 flatten the distribution; values below 1 sharpen it.
     * NOTE: a temperature of 0 makes the scale factor infinite and yields non-finite scores.
     *
     * @param instance the instance to score
     * @param temperature the softmax temperature
     * @param scores output array of length {@code numLabels}; receives probabilities
     */
    public void getClassificationScoresWithTemperature(Instance instance, double temperature, double[] scores) {
        this.getUnnormalizedClassificationScores(instance, scores);
        MatrixOps.timesEquals(scores, 1.0 / temperature);
        MaxEnt.softmaxInPlace(scores, this.getLabelAlphabet().size());
    }

    /**
     * Replaces the first {@code numLabels} entries of {@code scores} with their softmax. The array
     * maximum is subtracted before exponentiating so large scores do not overflow.
     */
    private static void softmaxInPlace(double[] scores, int numLabels) {
        double max = MatrixOps.max(scores);
        double sum = 0.0;
        for (int li = 0; li < numLabels; ++li) {
            scores[li] = Math.exp(scores[li] - max);
            sum += scores[li];
        }
        for (int li = 0; li < numLabels; ++li) {
            scores[li] /= sum;
        }
    }

    @Override
    public Classification classify(Instance instance) {
        int numClasses = this.getLabelAlphabet().size();
        double[] scores = new double[numClasses];
        this.getClassificationScores(instance, scores);
        return new Classification(instance, this, new LabelVector(this.getLabelAlphabet(), scores));
    }

    /** Prints every weight, grouped by label, to {@code System.out}. */
    @Override
    public void print() {
        this.print(System.out);
    }

    /**
     * Prints every weight, grouped by label: the bias first, then one line per feature.
     *
     * @param out destination writer; not flushed or closed here
     */
    @Override
    public void print(PrintWriter out) {
        Alphabet dict = this.getAlphabet();
        LabelAlphabet labelDict = this.getLabelAlphabet();
        int stride = this.featureStride();
        int numLabels = labelDict.size();
        for (int li = 0; li < numLabels; ++li) {
            int rowStart = li * stride;
            out.println("FEATURES FOR CLASS " + labelDict.lookupObject(li));
            out.println(" <default> " + this.parameters[rowStart + this.defaultFeatureIndex]);
            for (int i = 0; i < this.defaultFeatureIndex; ++i) {
                out.println(" " + dict.lookupObject(i) + " " + this.parameters[rowStart + i]);
            }
        }
    }

    /**
     * Prints every weight to a byte stream. The temporary writer is flushed, because
     * {@code PrintWriter(OutputStream)} buffers and would otherwise drop output. It is not closed,
     * so the caller's stream (e.g. {@code System.out}) stays open.
     *
     * @param out destination stream
     */
    public void print(PrintStream out) {
        PrintWriter writer = new PrintWriter(out);
        this.print(writer);
        writer.flush();
    }

    /**
     * Prints, for each label, all features ranked by weight, followed by the bias.
     *
     * @param out destination writer; not flushed or closed here
     */
    public void printRank(PrintWriter out) {
        Alphabet dict = this.getAlphabet();
        LabelAlphabet labelDict = this.getLabelAlphabet();
        int stride = this.featureStride();
        int numLabels = labelDict.size();
        double[] weights = new double[dict.size()];
        for (int li = 0; li < numLabels; ++li) {
            out.print("FEATURES FOR CLASS " + labelDict.lookupObject(li) + " ");
            // Copy this label's feature weights (excluding the bias column).
            System.arraycopy(this.parameters, li * stride, weights, 0, this.defaultFeatureIndex);
            RankedFeatureVector rfv = new RankedFeatureVector(dict, weights);
            rfv.printByRank(out);
            out.println(" <default> " + this.parameters[li * stride + this.defaultFeatureIndex] + " ");
        }
    }

    /**
     * Prints, for each label, the {@code num} highest-weighted features, the bias, then the
     * {@code num} lowest-weighted features.
     *
     * @param out destination writer; not flushed or closed here
     * @param num how many features to show at each extreme
     */
    public void printExtremeFeatures(PrintWriter out, int num) {
        Alphabet dict = this.getAlphabet();
        LabelAlphabet labelDict = this.getLabelAlphabet();
        int stride = this.featureStride();
        int numLabels = labelDict.size();
        double[] weights = new double[dict.size()];
        for (int li = 0; li < numLabels; ++li) {
            out.print("FEATURES FOR CLASS " + labelDict.lookupObject(li) + " ");
            // Copy this label's feature weights (excluding the bias column).
            System.arraycopy(this.parameters, li * stride, weights, 0, this.defaultFeatureIndex);
            RankedFeatureVector rfv = new RankedFeatureVector(dict, weights);
            rfv.printTopK(out, num);
            out.print(" <default> " + this.parameters[li * stride + this.defaultFeatureIndex] + " ");
            rfv.printLowerK(out, num);
            out.println();
        }
    }

    /**
     * Custom serialized form: version, pipe, parameters, bias index, then the optional
     * feature selections, each preceded by a presence marker.
     */
    private void writeObject(ObjectOutputStream out) throws IOException {
        out.writeInt(CURRENT_SERIAL_VERSION);
        out.writeObject(this.getInstancePipe());
        int np = this.parameters.length;
        out.writeInt(np);
        for (int p = 0; p < np; ++p) {
            out.writeDouble(this.parameters[p]);
        }
        out.writeInt(this.defaultFeatureIndex);
        if (this.featureSelection == null) {
            out.writeInt(NULL_INTEGER);
        } else {
            out.writeInt(PRESENT_MARKER);
            out.writeObject(this.featureSelection);
        }
        if (this.perClassFeatureSelection == null) {
            out.writeInt(NULL_INTEGER);
        } else {
            out.writeInt(this.perClassFeatureSelection.length);
            for (FeatureSelection selection : this.perClassFeatureSelection) {
                if (selection == null) {
                    out.writeInt(NULL_INTEGER);
                } else {
                    out.writeInt(PRESENT_MARKER);
                    out.writeObject(selection);
                }
            }
        }
    }

    /**
     * Reads the form produced by {@code writeObject}.
     * NOTE(review): Java native deserialization should not be fed untrusted streams.
     *
     * @throws ClassNotFoundException if the stored version differs from {@code CURRENT_SERIAL_VERSION}
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        int version = in.readInt();
        if (version != CURRENT_SERIAL_VERSION) {
            throw new ClassNotFoundException("Mismatched MaxEnt versions: wanted " + CURRENT_SERIAL_VERSION + ", got " + version);
        }
        this.instancePipe = (Pipe)in.readObject();
        int np = in.readInt();
        this.parameters = new double[np];
        for (int p = 0; p < np; ++p) {
            this.parameters[p] = in.readDouble();
        }
        this.defaultFeatureIndex = in.readInt();
        if (in.readInt() == PRESENT_MARKER) {
            this.featureSelection = (FeatureSelection)in.readObject();
        }
        int nfs = in.readInt();
        if (nfs >= 0) {
            this.perClassFeatureSelection = new FeatureSelection[nfs];
            for (int i = 0; i < nfs; ++i) {
                if (in.readInt() == PRESENT_MARKER) {
                    this.perClassFeatureSelection[i] = (FeatureSelection)in.readObject();
                }
            }
        }
    }
}

