/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.mi;

import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.RandomizableClassifier;
import weka.core.Capabilities;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.MultiInstanceCapabilitiesHandler;
import weka.core.Optimization;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Normalize;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;
import weka.filters.unsupervised.attribute.Standardize;

public class MIEMDD
extends RandomizableClassifier
implements OptionHandler,
MultiInstanceCapabilitiesHandler,
TechnicalInformationHandler {
    static final long serialVersionUID = 3899547154866223734L;
    protected int m_ClassIndex;
    protected double[] m_Par;
    protected int m_NumClasses;
    protected int[] m_Classes;
    protected double[][][] m_Data;
    protected Instances m_Attributes;
    protected double[][] m_emData;
    protected Filter m_Filter = null;
    protected int m_filterType = 1;
    public static final int FILTER_NORMALIZE = 0;
    public static final int FILTER_STANDARDIZE = 1;
    public static final int FILTER_NONE = 2;
    public static final Tag[] TAGS_FILTER = new Tag[]{new Tag(0, "Normalize training data"), new Tag(1, "Standardize training data"), new Tag(2, "No normalization/standardization")};
    protected ReplaceMissingValues m_Missing = new ReplaceMissingValues();

    public String globalInfo() {
        return "EMDD model builds heavily upon Dietterich's Diverse Density (DD) algorithm.\nIt is a general framework for MI learning of converting the MI problem to a single-instance setting using EM. In this implementation, we use most-likely cause DD model and only use 3 random selected postive bags as initial starting points of EM.\n\nFor more information see:\n\n" + this.getTechnicalInformation().toString();
    }

    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation technicalInformation = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        technicalInformation.setValue(TechnicalInformation.Field.AUTHOR, "Qi Zhang and Sally A. Goldman");
        technicalInformation.setValue(TechnicalInformation.Field.TITLE, "EM-DD: An Improved Multiple-Instance Learning Technique");
        technicalInformation.setValue(TechnicalInformation.Field.BOOKTITLE, "Advances in Neural Information Processing Systems 14");
        technicalInformation.setValue(TechnicalInformation.Field.YEAR, "2001");
        technicalInformation.setValue(TechnicalInformation.Field.PAGES, "1073-108");
        technicalInformation.setValue(TechnicalInformation.Field.PUBLISHER, "MIT Press");
        return technicalInformation;
    }

    public Enumeration listOptions() {
        Vector<Option> vector = new Vector<Option>();
        vector.addElement(new Option("\tWhether to 0=normalize/1=standardize/2=neither.\n\t(default 1=standardize)", "N", 1, "-N <num>"));
        Enumeration enumeration = super.listOptions();
        while (enumeration.hasMoreElements()) {
            vector.addElement((Option)enumeration.nextElement());
        }
        return vector.elements();
    }

    public void setOptions(String[] stringArray) throws Exception {
        String string = Utils.getOption('N', stringArray);
        if (string.length() != 0) {
            this.setFilterType(new SelectedTag(Integer.parseInt(string), TAGS_FILTER));
        } else {
            this.setFilterType(new SelectedTag(1, TAGS_FILTER));
        }
        super.setOptions(stringArray);
    }

    public String[] getOptions() {
        Vector<String> vector = new Vector<String>();
        String[] stringArray = super.getOptions();
        for (int i = 0; i < stringArray.length; ++i) {
            vector.add(stringArray[i]);
        }
        vector.add("-N");
        vector.add("" + this.m_filterType);
        return vector.toArray(new String[vector.size()]);
    }

    public String filterTypeTipText() {
        return "The filter type for transforming the training data.";
    }

    public SelectedTag getFilterType() {
        return new SelectedTag(this.m_filterType, TAGS_FILTER);
    }

    public void setFilterType(SelectedTag selectedTag) {
        if (selectedTag.getTags() == TAGS_FILTER) {
            this.m_filterType = selectedTag.getSelectedTag().getID();
        }
    }

    public Capabilities getCapabilities() {
        Capabilities capabilities = super.getCapabilities();
        capabilities.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.RELATIONAL_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.MISSING_VALUES);
        capabilities.enable(Capabilities.Capability.BINARY_CLASS);
        capabilities.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        capabilities.enable(Capabilities.Capability.ONLY_MULTIINSTANCE);
        return capabilities;
    }

    public Capabilities getMultiInstanceCapabilities() {
        Capabilities capabilities = super.getCapabilities();
        capabilities.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.DATE_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.MISSING_VALUES);
        capabilities.disableAllClasses();
        capabilities.enable(Capabilities.Capability.NO_CLASS);
        return capabilities;
    }

    public void buildClassifier(Instances instances) throws Exception {
        int n;
        int n2;
        int n3;
        int n4;
        int n5;
        this.getCapabilities().testWithFail(instances);
        instances = new Instances(instances);
        instances.deleteWithMissingClass();
        this.m_ClassIndex = instances.classIndex();
        this.m_NumClasses = instances.numClasses();
        int n6 = instances.attribute(1).relation().numAttributes();
        int n7 = instances.numInstances();
        int[] nArray = new int[n7];
        Instances instances2 = new Instances(instances.attribute(1).relation(), 0);
        this.m_Data = new double[n7][n6][];
        this.m_Classes = new int[n7];
        this.m_Attributes = instances2.stringFreeStructure();
        if (this.m_Debug) {
            System.out.println("\n\nExtracting data...");
        }
        for (n5 = 0; n5 < n7; ++n5) {
            Instance instance = instances.instance(n5);
            this.m_Classes[n5] = (int)instance.classValue();
            Instances instances3 = instance.relationalValue(1);
            for (n4 = 0; n4 < instances3.numInstances(); ++n4) {
                Instance instance2 = instances3.instance(n4);
                instances2.add(instance2);
            }
            nArray[n5] = n4 = instances3.numInstances();
        }
        this.m_Filter = this.m_filterType == 1 ? new Standardize() : (this.m_filterType == 0 ? new Normalize() : null);
        if (this.m_Filter != null) {
            this.m_Filter.setInputFormat(instances2);
            instances2 = Filter.useFilter(instances2, this.m_Filter);
        }
        this.m_Missing.setInputFormat(instances2);
        instances2 = Filter.useFilter(instances2, this.m_Missing);
        n5 = 0;
        int n8 = 0;
        for (int i = 0; i < n7; ++i) {
            for (n4 = 0; n4 < instances2.numAttributes(); ++n4) {
                this.m_Data[i][n4] = new double[nArray[i]];
                n5 = n8;
                for (int j = 0; j < nArray[i]; ++j) {
                    this.m_Data[i][n4][j] = instances2.instance(n5).value(n4);
                    ++n5;
                }
            }
            n8 = n5;
        }
        if (this.m_Debug) {
            System.out.println("\n\nIteration History...");
        }
        this.m_emData = new double[n7][n6];
        this.m_Par = new double[2 * n6];
        double[] dArray = new double[n6 * 2];
        double[] dArray2 = new double[dArray.length];
        double[] dArray3 = new double[dArray.length];
        double[] dArray4 = new double[dArray.length];
        double[][] dArray5 = new double[2][dArray.length];
        double d = Double.MAX_VALUE;
        double d2 = Double.MAX_VALUE;
        for (int i = 0; i < dArray.length; ++i) {
            dArray5[0][i] = Double.NaN;
            dArray5[1][i] = Double.NaN;
        }
        Random random = new Random(this.getSeed());
        FastVector fastVector = new FastVector();
        while (this.m_Classes[n3 = random.nextInt(n7 - 1)] == 0) {
        }
        fastVector.addElement(new Integer(n3));
        while ((n2 = random.nextInt(n7 - 1)) == n3 || this.m_Classes[n2] == 0) {
        }
        fastVector.addElement(new Integer(n2));
        while ((n = random.nextInt(n7 - 1)) == n3 || n == n2 || this.m_Classes[n] == 0) {
        }
        fastVector.addElement(new Integer(n));
        for (int i = 0; i < fastVector.size(); ++i) {
            int n9 = (Integer)fastVector.elementAt(i);
            if (this.m_Debug) {
                System.out.println("\nH0 at " + n9);
            }
            for (int j = 0; j < this.m_Data[n9][0].length; ++j) {
                int n10;
                int n11;
                int n12;
                for (n12 = 0; n12 < n6; ++n12) {
                    dArray[2 * n12] = this.m_Data[n9][n12][j];
                    dArray[2 * n12 + 1] = 1.0;
                }
                double d3 = Double.MAX_VALUE;
                double d4 = 1.7976931348623158E307;
                int n13 = 0;
                while (d4 < d3 && n13 < 10) {
                    ++n13;
                    d3 = d4;
                    if (this.m_Debug) {
                        System.out.println("\niteration: " + n13);
                    }
                    for (n12 = 0; n12 < this.m_Data.length; ++n12) {
                        n11 = this.findInstance(n12, dArray);
                        for (n10 = 0; n10 < this.m_Data[0].length; ++n10) {
                            this.m_emData[n12][n10] = this.m_Data[n12][n10][n11];
                        }
                    }
                    if (this.m_Debug) {
                        System.out.println("E-step for new H' finished");
                    }
                    OptEng optEng = new OptEng();
                    dArray2 = optEng.findArgmin(dArray, dArray5);
                    while (dArray2 == null) {
                        dArray2 = optEng.getVarbValues();
                        if (this.m_Debug) {
                            System.out.println("200 iterations finished, not enough!");
                        }
                        dArray2 = optEng.findArgmin(dArray2, dArray5);
                    }
                    d4 = optEng.getMinFunction();
                    dArray3 = dArray;
                    dArray = dArray2;
                }
                double[] dArray6 = new double[2];
                n11 = 0;
                this.m_Par = d4 > d3 ? dArray3 : dArray;
                for (n10 = 0; n10 < instances.numInstances(); ++n10) {
                    dArray6 = this.distributionForInstance(instances.instance(n10));
                    if (dArray6[1] >= 0.5 && this.m_Classes[n10] == 0) {
                        ++n11;
                        continue;
                    }
                    if (!(dArray6[1] < 0.5) || this.m_Classes[n10] != 1) continue;
                    ++n11;
                }
                if (!((double)n11 < d2)) continue;
                dArray4 = this.m_Par;
                d2 = n11;
                d = d4 > d3 ? d3 : d4;
                if (!this.m_Debug) continue;
                System.out.println("error= " + n11 + "  nll= " + d);
            }
            if (!this.m_Debug) continue;
            System.out.println(n9 + ":  -------------<Converged>--------------");
            System.out.println("current minimum error= " + d2 + "  nll= " + d);
        }
        this.m_Par = dArray4;
    }

    protected int findInstance(int n, double[] dArray) {
        double d = Double.MAX_VALUE;
        int n2 = 0;
        int n3 = this.m_Data[n][0].length;
        for (int i = 0; i < n3; ++i) {
            double d2 = 0.0;
            for (int j = 0; j < this.m_Data[n].length; ++j) {
                d2 += (this.m_Data[n][j][i] - dArray[j * 2]) * (this.m_Data[n][j][i] - dArray[j * 2]) * dArray[j * 2 + 1] * dArray[j * 2 + 1];
            }
            if (!(d2 < d)) continue;
            d = d2;
            n2 = i;
        }
        return n2;
    }

    public double[] distributionForInstance(Instance instance) throws Exception {
        Instances instances = instance.relationalValue(1);
        if (this.m_Filter != null) {
            instances = Filter.useFilter(instances, this.m_Filter);
        }
        instances = Filter.useFilter(instances, this.m_Missing);
        int n = instances.numInstances();
        int n2 = instances.numAttributes();
        double[][] dArray = new double[n][n2];
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n2; ++j) {
                dArray[i][j] = instances.instance(i).value(j);
            }
        }
        double d = Double.MAX_VALUE;
        double d2 = -1.0;
        for (int i = 0; i < n; ++i) {
            double d3 = 0.0;
            for (int j = 0; j < n2; ++j) {
                d3 += (dArray[i][j] - this.m_Par[j * 2]) * (dArray[i][j] - this.m_Par[j * 2]) * this.m_Par[j * 2 + 1] * this.m_Par[j * 2 + 1];
            }
            if (!(d3 < d)) continue;
            d = d3;
            d2 = Math.exp(-d3);
        }
        double[] dArray2 = new double[2];
        dArray2[1] = d2;
        dArray2[0] = 1.0 - dArray2[1];
        return dArray2;
    }

    public String toString() {
        String string = "MIEMDD";
        if (this.m_Par == null) {
            return string + ": No model built yet.";
        }
        string = string + "\nCoefficients...\nVariable       Point       Scale\n";
        int n = 0;
        int n2 = 0;
        while (n < this.m_Par.length / 2) {
            string = string + this.m_Attributes.attribute(n2).name();
            string = string + " " + Utils.doubleToString(this.m_Par[n * 2], 12, 4);
            string = string + " " + Utils.doubleToString(this.m_Par[n * 2 + 1], 12, 4) + "\n";
            ++n;
            ++n2;
        }
        return string;
    }

    public static void main(String[] stringArray) {
        MIEMDD.runClassifier(new MIEMDD(), stringArray);
    }

    private class OptEng
    extends Optimization {
        private OptEng() {
        }

        protected double objectiveFunction(double[] dArray) {
            double d = 0.0;
            for (int i = 0; i < MIEMDD.this.m_Classes.length; ++i) {
                double d2 = 0.0;
                for (int j = 0; j < MIEMDD.this.m_emData[i].length; ++j) {
                    d2 += (MIEMDD.this.m_emData[i][j] - dArray[j * 2]) * (MIEMDD.this.m_emData[i][j] - dArray[j * 2]) * dArray[j * 2 + 1] * dArray[j * 2 + 1];
                }
                d2 = Math.exp(-d2);
                if (MIEMDD.this.m_Classes[i] == 1) {
                    if (d2 <= m_Zero) {
                        d2 = m_Zero;
                    }
                    d -= Math.log(d2);
                    continue;
                }
                if ((d2 = 1.0 - d2) <= m_Zero) {
                    d2 = m_Zero;
                }
                d -= Math.log(d2);
            }
            return d;
        }

        protected double[] evaluateGradient(double[] dArray) {
            double[] dArray2 = new double[dArray.length];
            for (int i = 0; i < MIEMDD.this.m_Classes.length; ++i) {
                int n;
                double[] dArray3 = new double[dArray.length];
                double d = 0.0;
                for (n = 0; n < MIEMDD.this.m_emData[i].length; ++n) {
                    d += (MIEMDD.this.m_emData[i][n] - dArray[n * 2]) * (MIEMDD.this.m_emData[i][n] - dArray[n * 2]) * dArray[n * 2 + 1] * dArray[n * 2 + 1];
                }
                d = Math.exp(-d);
                for (n = 0; n < MIEMDD.this.m_emData[i].length; ++n) {
                    dArray3[2 * n] = 2.0 * (dArray[2 * n] - MIEMDD.this.m_emData[i][n]) * dArray[n * 2 + 1] * dArray[n * 2 + 1];
                    dArray3[2 * n + 1] = 2.0 * (dArray[2 * n] - MIEMDD.this.m_emData[i][n]) * (dArray[2 * n] - MIEMDD.this.m_emData[i][n]) * dArray[n * 2 + 1];
                }
                for (n = 0; n < MIEMDD.this.m_emData[i].length; ++n) {
                    if (MIEMDD.this.m_Classes[i] == 1) {
                        int n2 = 2 * n;
                        dArray2[n2] = dArray2[n2] + dArray3[2 * n];
                        int n3 = 2 * n + 1;
                        dArray2[n3] = dArray2[n3] + dArray3[2 * n + 1];
                        continue;
                    }
                    int n4 = 2 * n;
                    dArray2[n4] = dArray2[n4] - dArray3[2 * n] * d / (1.0 - d);
                    int n5 = 2 * n + 1;
                    dArray2[n5] = dArray2[n5] - dArray3[2 * n + 1] * d / (1.0 - d);
                }
            }
            return dArray2;
        }
    }
}

