/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.mi;

import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.SingleClassifierEnhancer;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.MultiInstanceCapabilitiesHandler;
import weka.core.Optimization;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Discretize;
import weka.filters.unsupervised.attribute.MultiInstanceToPropositional;

/**
 * Multi-instance AdaBoost (MIBoost).
 *
 * <p>Training flattens every bag into its member instances with a
 * {@link MultiInstanceToPropositional} filter, optionally followed by a {@link Discretize}
 * filter, and trains one weighted copy of the base classifier per boosting round on those
 * instances. A bag's error in a round is the fraction of its instances that the round's model
 * misclassifies. The round weight {@code c} is obtained by numerically minimising
 * {@code sum_b w_b * exp(c * (2 * err_b - 1))}, and the bag weights are then multiplied by the
 * same exponential factor.
 *
 * <p>At prediction time every instance of the bag votes, through every retained model, for its
 * predicted class with weight {@code beta / bagSize}. The per-class totals are exponentiated and
 * normalised.
 */
public class MIBoost
extends SingleClassifierEnhancer
implements OptionHandler,
MultiInstanceCapabilitiesHandler,
TechnicalInformationHandler {
    static final long serialVersionUID = -3808427225599279539L;
    /** One copy of the base classifier per possible boosting round (length = max iterations). */
    protected Classifier[] m_Models;
    /** Number of class values in the training data. */
    protected int m_NumClasses;
    /** Not referenced by the code in this class. */
    protected int[] m_Classes;
    /** Not referenced by the code in this class. */
    protected Instances m_Attributes;
    /** Number of rounds retained by the last build; set by {@link #buildClassifier(Instances)}. */
    private int m_NumIterations = 100;
    /** Weight of each retained model (index = boosting round). */
    protected double[] m_Beta;
    /** Maximum number of boosting rounds (option -R, default 10). */
    protected int m_MaxIterations = 10;
    /** Number of discretisation bins (option -B); 0 disables discretisation. */
    protected int m_DiscretizeBin = 0;
    /** Discretisation filter fitted by the last build, or null if the build did not discretise. */
    protected Discretize m_Filter = null;
    /** Converts multi-instance bags into single instances. */
    protected MultiInstanceToPropositional m_ConvertToSI = new MultiInstanceToPropositional();

    /**
     * Returns a description of this classifier for display in the GUI.
     *
     * @return the description, including the technical reference
     */
    public String globalInfo() {
        return "MI AdaBoost method, considers the geometric mean of posterior of instances inside a bag (arithmatic mean of log-posterior) and the expectation for a bag is taken inside the loss function.\n\nFor more information about Adaboost, see:\n\n" + this.getTechnicalInformation().toString();
    }

    /**
     * Returns the bibliographic reference for the AdaBoost algorithm.
     *
     * @return the technical information (Freund and Schapire, ICML 1996)
     */
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation info = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        info.setValue(TechnicalInformation.Field.AUTHOR, "Yoav Freund and Robert E. Schapire");
        info.setValue(TechnicalInformation.Field.TITLE, "Experiments with a new boosting algorithm");
        info.setValue(TechnicalInformation.Field.BOOKTITLE, "Thirteenth International Conference on Machine Learning");
        info.setValue(TechnicalInformation.Field.YEAR, "1996");
        info.setValue(TechnicalInformation.Field.PAGES, "148-156");
        info.setValue(TechnicalInformation.Field.PUBLISHER, "Morgan Kaufmann");
        info.setValue(TechnicalInformation.Field.ADDRESS, "San Francisco");
        return info;
    }

    /**
     * Lists the options of this classifier, followed by those of the base classifier.
     *
     * @return an enumeration of {@link Option} objects
     */
    public Enumeration listOptions() {
        Vector<Option> options = new Vector<Option>();
        options.addElement(new Option("\tTurn on debugging output.", "D", 0, "-D"));
        options.addElement(new Option("\tThe number of bins in discretization\n\t(default 0, no discretization)", "B", 1, "-B <num>"));
        options.addElement(new Option("\tMaximum number of boost iterations.\n\t(default 10)", "R", 1, "-R <num>"));
        options.addElement(new Option("\tFull name of classifier to boost.\n\teg: weka.classifiers.bayes.NaiveBayes", "W", 1, "-W <class name>"));
        Enumeration<?> baseOptions = this.m_Classifier.listOptions();
        while (baseOptions.hasMoreElements()) {
            options.addElement((Option) baseOptions.nextElement());
        }
        return options.elements();
    }

    /**
     * Parses the options of this classifier and of the base classifier.
     *
     * <p>Recognised here: {@code -D} (debug output), {@code -B <num>} (discretisation bins,
     * default 0) and {@code -R <num>} (maximum boosting rounds, default 10). The remaining
     * options are handed to the superclass.
     *
     * @param options the option array; recognised options are consumed from it
     * @throws Exception if a numeric value cannot be parsed or the superclass rejects an option
     */
    public void setOptions(String[] options) throws Exception {
        // Utils.getFlag blanks out the flag it finds. A superclass that parses -D again would
        // therefore see it as absent and could switch debugging off, so re-apply it at the end.
        boolean debug = Utils.getFlag('D', options);
        this.setDebug(debug);

        String bins = Utils.getOption('B', options);
        this.setDiscretizeBin(bins.length() != 0 ? Integer.parseInt(bins) : 0);

        String iterations = Utils.getOption('R', options);
        this.setMaxIterations(iterations.length() != 0 ? Integer.parseInt(iterations) : 10);

        super.setOptions(options);
        this.setDebug(debug);
    }

    /**
     * Returns the current option settings: -R and -B followed by the superclass options.
     *
     * @return the option array
     */
    public String[] getOptions() {
        Vector<String> options = new Vector<String>();
        options.add("-R");
        options.add("" + this.getMaxIterations());
        options.add("-B");
        options.add("" + this.getDiscretizeBin());
        for (String option : super.getOptions()) {
            options.add(option);
        }
        return options.toArray(new String[options.size()]);
    }

    /** @return tip text for the maxIterations property */
    public String maxIterationsTipText() {
        return "The maximum number of boost iterations.";
    }

    /** @param n the maximum number of boosting rounds */
    public void setMaxIterations(int n) {
        this.m_MaxIterations = n;
    }

    /** @return the maximum number of boosting rounds */
    public int getMaxIterations() {
        return this.m_MaxIterations;
    }

    /** @return tip text for the discretizeBin property */
    public String discretizeBinTipText() {
        return "The number of bins in discretization.";
    }

    /** @param n the number of discretisation bins; 0 or less disables discretisation */
    public void setDiscretizeBin(int n) {
        this.m_DiscretizeBin = n;
    }

    /** @return the number of discretisation bins */
    public int getDiscretizeBin() {
        return this.m_DiscretizeBin;
    }

    /**
     * Returns the capabilities of this classifier.
     *
     * <p>Nominal/relational attributes, missing values and missing class values are supported.
     * Only multi-instance data is accepted. A binary class is supported only if the superclass
     * capabilities report it.
     *
     * @return the capabilities
     */
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.RELATIONAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.MISSING_VALUES);
        result.disableAllClasses();
        result.disableAllClassDependencies();
        if (super.getCapabilities().handles(Capabilities.Capability.BINARY_CLASS)) {
            result.enable(Capabilities.Capability.BINARY_CLASS);
        }
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        result.enable(Capabilities.Capability.ONLY_MULTIINSTANCE);
        return result;
    }

    /**
     * Returns the capabilities required for the relational (bag) data: the superclass
     * capabilities with all class types replaced by NO_CLASS.
     *
     * @return the multi-instance capabilities
     */
    public Capabilities getMultiInstanceCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAllClasses();
        result.enable(Capabilities.Capability.NO_CLASS);
        return result;
    }

    /**
     * Builds the boosted ensemble.
     *
     * <p>Boosting stops early, at round {@code i}, in either of two cases:
     * <ul>
     *   <li>No bag has an error rate above 0.5, or no bag has an error rate below 0.5.</li>
     *   <li>The optimised round weight is infinite or not positive.</li>
     * </ul>
     * In both cases the round's weight is set to 1.0 if it is the first round and 0.0
     * otherwise, and {@code i + 1} rounds are retained.
     *
     * @param instances the multi-instance training data (not modified)
     * @throws Exception if no base classifier is set, the base classifier cannot handle weights,
     *         the data fails the capability test, or a weight/objective becomes NaN
     */
    public void buildClassifier(Instances instances) throws Exception {
        // Validate the configuration first: getCapabilities() consults the base classifier.
        if (this.m_Classifier == null) {
            throw new Exception("A base classifier has not been specified!");
        }
        if (!(this.m_Classifier instanceof WeightedInstancesHandler)) {
            throw new Exception("Base classifier cannot handle weighted instances!");
        }
        this.getCapabilities().testWithFail(instances);

        Instances bags = new Instances(instances);
        bags.deleteWithMissingClass();
        this.m_NumClasses = bags.numClasses();
        this.m_NumIterations = this.m_MaxIterations;
        this.m_Models = Classifier.makeCopies(this.m_Classifier, this.getMaxIterations());
        if (this.m_Debug) {
            System.err.println("Base classifier: " + this.m_Classifier.getClass().getName());
        }
        this.m_Beta = new double[this.m_NumIterations];

        // Every bag starts with weight = average bag size (total instances / number of bags).
        int numBags = bags.numInstances();
        double totalInstances = 0.0;
        for (int b = 0; b < numBags; ++b) {
            totalInstances += bags.instance(b).relationalValue(1).numInstances();
        }
        for (int b = 0; b < numBags; ++b) {
            bags.instance(b).setWeight(totalInstances / numBags);
        }

        // Flatten the bags into single instances. The code below assumes the converter emits the
        // instances bag by bag, in bag order; attribute 0 is dropped (presumably the bag id added
        // by the converter -- TODO confirm).
        this.m_ConvertToSI.setInputFormat(bags);
        Instances singles = Filter.useFilter(bags, this.m_ConvertToSI);
        singles.deleteAttributeAt(0);

        // Reset on every build so prediction applies exactly the preprocessing used here.
        this.m_Filter = null;
        if (this.m_DiscretizeBin > 0) {
            this.m_Filter = new Discretize();
            this.m_Filter.setBins(this.m_DiscretizeBin);
            this.m_Filter.setInputFormat(new Instances(singles, 0));
            singles = Filter.useFilter(singles, this.m_Filter);
        }

        for (int i = 0; i < this.m_MaxIterations; ++i) {
            if (this.m_Debug) {
                System.err.println("\nIteration " + i);
            }
            this.m_Models[i].buildClassifier(singles);

            // Per-bag error rate of this round's model and the bag weights it was trained with.
            double[] errors = new double[numBags];
            double[] bagWeights = new double[numBags];
            boolean noBagMostlyWrong = true;
            boolean noBagMostlyRight = true;
            int row = 0;
            for (int b = 0; b < numBags; ++b) {
                Instance bag = bags.instance(b);
                int bagSize = bag.relationalValue(1).numInstances();
                int bagClass = (int) bag.classValue();
                for (int k = 0; k < bagSize; ++k) {
                    if ((int) this.m_Models[i].classifyInstance(singles.instance(row++)) != bagClass) {
                        errors[b] += 1.0;
                    }
                }
                bagWeights[b] = bag.weight();
                errors[b] /= bagSize;
                if (errors[b] > 0.5) {
                    noBagMostlyWrong = false;
                }
                if (errors[b] < 0.5) {
                    noBagMostlyRight = false;
                }
            }
            if (noBagMostlyWrong || noBagMostlyRight) {
                this.m_Beta[i] = i == 0 ? 1.0 : 0.0;
                this.m_NumIterations = i + 1;
                if (this.m_Debug) {
                    System.err.println("No errors");
                }
                break;
            }

            // Find the round weight c; NaN bounds leave the single variable unconstrained.
            double[] c = new double[]{0.0};
            double[][] bounds = new double[2][c.length];
            bounds[0][0] = Double.NaN;
            bounds[1][0] = Double.NaN;
            OptEng optimizer = new OptEng();
            optimizer.setWeights(bagWeights);
            optimizer.setErrs(errors);
            if (this.m_Debug) {
                System.err.println("Start searching for c... ");
            }
            c = optimizer.findArgmin(c, bounds);
            // findArgmin returns null when it stops before converging; resume from where it was.
            while (c == null) {
                c = optimizer.getVarbValues();
                if (this.m_Debug) {
                    System.err.println("200 iterations finished, not enough!");
                }
                c = optimizer.findArgmin(c, bounds);
            }
            if (this.m_Debug) {
                System.err.println("Finished.");
            }
            this.m_Beta[i] = c[0];
            if (this.m_Debug) {
                System.err.println("c = " + this.m_Beta[i]);
            }
            if (Double.isInfinite(this.m_Beta[i]) || Utils.smOrEq(this.m_Beta[i], 0.0)) {
                this.m_Beta[i] = i == 0 ? 1.0 : 0.0;
                this.m_NumIterations = i + 1;
                if (this.m_Debug) {
                    System.err.println("Errors out of range!");
                }
                break;
            }

            // Reweight bags: w_b <- w_b * exp(c * (2 * err_b - 1)).
            double totalWeight = 0.0;
            for (int b = 0; b < numBags; ++b) {
                Instance bag = bags.instance(b);
                bag.setWeight(bagWeights[b] * Math.exp(this.m_Beta[i] * (2.0 * errors[b] - 1.0)));
                totalWeight += bag.weight();
            }
            if (this.m_Debug) {
                System.err.println("Total weights = " + totalWeight);
            }

            // Rescale bag weights to sum to the total instance count, then split each bag's
            // weight evenly over its single instances.
            row = 0;
            for (int b = 0; b < numBags; ++b) {
                Instance bag = bags.instance(b);
                int bagSize = bag.relationalValue(1).numInstances();
                bag.setWeight(totalInstances * bag.weight() / totalWeight);
                for (int k = 0; k < bagSize; ++k) {
                    Instance single = singles.instance(row);
                    single.setWeight(bag.weight() / bagSize);
                    if (Double.isNaN(single.weight())) {
                        throw new Exception("instance " + k + " in bag " + b + " has weight NaN!");
                    }
                    ++row;
                }
            }
        }
    }

    /**
     * Computes the class distribution for a bag.
     *
     * <p>Each instance of the bag adds {@code beta_i / bagSize} to the class predicted by model
     * {@code i}, for every retained model. The totals are exponentiated and normalised.
     *
     * @param instance the bag to classify
     * @return the normalised class distribution
     * @throws Exception if filtering or a base model fails
     */
    public double[] distributionForInstance(Instance instance) throws Exception {
        double[] distribution = new double[this.m_NumClasses];
        Instances singles = new Instances(instance.dataset(), 0);
        singles.add(instance);
        singles = Filter.useFilter(singles, this.m_ConvertToSI);
        singles.deleteAttributeAt(0);
        int bagSize = singles.numInstances();
        // Discretise exactly when training did (the filter is null otherwise), regardless of any
        // later change to the -B setting.
        if (this.m_Filter != null) {
            singles = Filter.useFilter(singles, this.m_Filter);
        }
        for (int k = 0; k < bagSize; ++k) {
            Instance single = singles.instance(k);
            for (int i = 0; i < this.m_NumIterations; ++i) {
                int predicted = (int) this.m_Models[i].classifyInstance(single);
                distribution[predicted] += this.m_Beta[i] / bagSize;
            }
        }
        for (int j = 0; j < distribution.length; ++j) {
            distribution[j] = Math.exp(distribution[j]);
        }
        Utils.normalize(distribution);
        return distribution;
    }

    /**
     * Returns a textual description of the built model.
     *
     * @return the models and their weights, or a "not built" message
     */
    public String toString() {
        if (this.m_Models == null) {
            return "No model built yet!";
        }
        StringBuilder text = new StringBuilder();
        text.append("MIBoost: number of bins in discretization = " + this.m_DiscretizeBin + "\n");
        if (this.m_NumIterations == 0) {
            text.append("No model built yet.\n");
        } else if (this.m_NumIterations == 1) {
            text.append("No boosting possible, one classifier used: Weight = " + Utils.roundDouble(this.m_Beta[0], 2) + "\n");
            text.append("Base classifiers:\n" + this.m_Models[0].toString());
        } else {
            text.append("Base classifiers and their weights: \n");
            for (int i = 0; i < this.m_NumIterations; ++i) {
                text.append("\n\n" + i + ": Weight = " + Utils.roundDouble(this.m_Beta[i], 2) + "\nBase classifier:\n" + this.m_Models[i].toString());
            }
        }
        text.append("\n\nNumber of performed Iterations: " + this.m_NumIterations + "\n");
        return text.toString();
    }

    /**
     * Command-line entry point.
     *
     * @param args the command-line options
     */
    public static void main(String[] args) {
        MIBoost.runClassifier(new MIBoost(), args);
    }

    /**
     * One-dimensional optimiser for a boosting round's weight {@code c}. It minimises
     * {@code f(c) = sum_b w_b * exp(c * (2 * err_b - 1))}. Static because it needs no state of
     * the enclosing classifier.
     */
    private static class OptEng
    extends Optimization {
        /** Bag weights w_b used in the current round. */
        private double[] weights;
        /** Per-bag error rates err_b of the current round's model. */
        private double[] errs;

        private OptEng() {
        }

        /** @param weights the bag weights (must have the same length as the error array) */
        public void setWeights(double[] weights) {
            this.weights = weights;
        }

        /** @param errs the per-bag error rates */
        public void setErrs(double[] errs) {
            this.errs = errs;
        }

        /**
         * Evaluates {@code f(c)}.
         *
         * @param x single-element array holding c
         * @return the objective value
         * @throws Exception if the running sum becomes NaN
         */
        protected double objectiveFunction(double[] x) throws Exception {
            double value = 0.0;
            for (int b = 0; b < this.weights.length; ++b) {
                value += this.weights[b] * Math.exp(x[0] * (2.0 * this.errs[b] - 1.0));
                if (Double.isNaN(value)) {
                    throw new Exception("Objective function value is NaN!");
                }
            }
            return value;
        }

        /**
         * Evaluates {@code f'(c) = sum_b w_b * (2 * err_b - 1) * exp(c * (2 * err_b - 1))}.
         *
         * @param x single-element array holding c
         * @return single-element gradient array
         * @throws Exception if the running sum becomes NaN
         */
        protected double[] evaluateGradient(double[] x) throws Exception {
            double[] gradient = new double[1];
            for (int b = 0; b < this.weights.length; ++b) {
                double sign = 2.0 * this.errs[b] - 1.0;
                gradient[0] += this.weights[b] * sign * Math.exp(x[0] * sign);
                if (Double.isNaN(gradient[0])) {
                    throw new Exception("Gradient is NaN!");
                }
            }
            return gradient;
        }
    }
}

