/*
 * Decompiled with CFR 0.152.
 */
package eu.kliegr.ac1.rule.extend;

import eu.kliegr.ac1.data.Attribute;
import eu.kliegr.ac1.rule.Consequent;
import eu.kliegr.ac1.rule.Prediction;
import eu.kliegr.ac1.rule.RuleQuality;
import eu.kliegr.ac1.rule.extend.ArrayIndexComparator;
import eu.kliegr.ac1.rule.extend.AttributeValueAnnotation;
import eu.kliegr.ac1.rule.extend.Distribution;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Set;
import java.util.concurrent.ThreadLocalRandom;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Builds, combines and evaluates {@link Distribution}s over the set of rule consequents.
 *
 * <p>The order of consequents (i.e. the meaning of each index in a probability array built by
 * this factory) is fixed by the first call to {@link #convert(AttributeValueAnnotation)}.
 * {@link #aggregate(ArrayList)}, {@link #aggregateDistributions(HashMap)} and both
 * {@code getMax} methods read that cached order, so {@code convert} must have been called at
 * least once before them. The lazy initialisation in {@code convert} is not synchronized.
 */
public class DistributionFactory {
    private static final Logger LOGGER = Logger.getLogger(DistributionFactory.class.getName());
    /** Consequents in the order used to index every probability array built by this factory. */
    private ArrayList<Consequent> consequents;
    /** Uniform weight 1/n per consequent; filled by init() but not read by any method of this class. */
    private Float[] equalWeights;

    /**
     * Caches the consequent order and precomputes uniform weights.
     *
     * @param cons the consequents; their iteration order becomes the index order of all
     *             distributions produced afterwards
     */
    private void init(Set<Consequent> cons) {
        this.consequents = new ArrayList<Consequent>(cons);
        this.equalWeights = new Float[cons.size()];
        // Every slot holds the same value, so compute it once instead of per iteration.
        Arrays.fill(this.equalWeights, Float.valueOf((float) (1.0 / (double) cons.size())));
    }

    /**
     * Returns the consequent with the highest probability in {@code finalDistr}.
     *
     * <p>Ties are resolved in favour of the lowest index. If no probability is strictly positive
     * (all zero, negative or NaN), a SEVERE message is logged and a consequent is chosen uniformly
     * at random; the returned prediction then carries a trust of {@code 0}.
     *
     * @param finalDistr distribution whose probability array is indexed like {@link #consequents}
     * @return the winning prediction
     */
    public Prediction getMax(Distribution finalDistr) {
        float[] probs = finalDistr.getProbs();
        // Start at 0 rather than Float.MIN_VALUE: MIN_VALUE is the smallest *positive* float,
        // and using it as a sentinel leaked ~1.4e-45 into the returned trust on the fallback path.
        float max = 0.0f;
        int maxIndex = -1;
        for (int i = 0; i < probs.length; ++i) {
            if (probs[i] > max) {
                max = probs[i];
                maxIndex = i;
            }
        }
        if (maxIndex == -1) {
            LOGGER.severe("Distribution has all probabilities set to zero. Picking class at random to avoid failure");
            maxIndex = ThreadLocalRandom.current().nextInt(0, this.consequents.size());
        }
        return new Prediction(this.consequents.get(maxIndex), max);
    }

    /**
     * Returns up to {@code topn} predictions, most probable first.
     *
     * <p>For {@code topn == 1} this delegates to {@link #getMax(Distribution)} (including its
     * random fallback); otherwise it sorts by probability and has no fallback.
     *
     * @param finalDistr distribution whose probability array is indexed like {@link #consequents}
     * @param topn       maximum number of predictions; values below 0 yield an empty array
     * @return at most {@code topn} predictions
     */
    public Prediction[] getMax(Distribution finalDistr, int topn) {
        if (topn == 1) {
            return new Prediction[]{this.getMax(finalDistr)};
        }
        float[] probs = finalDistr.getProbs();
        ArrayIndexComparator comparator = new ArrayIndexComparator(probs);
        Integer[] indexes = comparator.createIndexArray();
        Arrays.sort(indexes, comparator);
        // Clamp to [0, indexes.length]; a negative topn previously threw NegativeArraySizeException.
        int length = Math.max(0, Math.min(topn, indexes.length));
        Prediction[] result = new Prediction[length];
        for (int i = 0; i < length; ++i) {
            // NOTE(review): assumes the comparator orders indexes by ascending probability, so walking
            // from the end yields the most probable consequents first — confirm in ArrayIndexComparator.
            int index = indexes[indexes.length - 1 - i];
            result[i] = new Prediction(this.consequents.get(index), probs[index]);
        }
        return result;
    }

    /**
     * Aggregates distributions in two levels: first within each attribute, then across the
     * per-attribute results, both via {@link #aggregate(ArrayList)}.
     *
     * @param distributionsByAttribute distributions grouped by attribute
     * @return the final distribution, or {@code null} if no attribute contributed any distribution
     */
    public Distribution aggregateDistributions(HashMap<Attribute, ArrayList<Distribution>> distributionsByAttribute) {
        ArrayList<Distribution> perAttribute = new ArrayList<Distribution>(distributionsByAttribute.size());
        for (ArrayList<Distribution> distsForAttribute : distributionsByAttribute.values()) {
            Distribution aggregated = this.aggregate(distsForAttribute);
            // aggregate() returns null for an empty list; adding that null used to make the
            // second-level aggregate() throw a NullPointerException on d.getWeight().
            if (aggregated != null) {
                perAttribute.add(aggregated);
            }
        }
        return this.aggregate(perAttribute);
    }

    /**
     * Computes the weighted average of {@code distribs}: for each consequent index {@code i},
     * {@code result[i] = sum_j probs_j[i] * weight_j / sum_k weight_k}.
     *
     * <p>NOTE(review): if all weights sum to 0, every normalized weight is 0/0 = NaN and so is
     * every entry of the result.
     *
     * @param distribs distributions whose probability arrays have at least
     *                 {@code consequents.size()} entries, indexed like {@link #consequents}
     * @return the averaged distribution, or {@code null} if {@code distribs} is empty
     */
    public Distribution aggregate(ArrayList<Distribution> distribs) {
        if (distribs.isEmpty()) {
            if (LOGGER.isLoggable(Level.FINE)) {
                LOGGER.fine("Nothing to aggregate");
            }
            return null;
        }
        int count = distribs.size();
        float[] weights = new float[count];
        float weightSum = 0.0f;
        for (int j = 0; j < count; ++j) {
            weights[j] = distribs.get(j).getWeight();
            weightSum += weights[j];
        }
        float[] distrib = new float[this.consequents.size()];
        // Iterating distributions in the outer loop keeps the per-class summation order
        // (j = 0..count-1) identical while fetching each probability array only once.
        for (int j = 0; j < count; ++j) {
            float[] probs = distribs.get(j).getProbs();
            float normalizedWeight = weights[j] / weightSum;
            for (int i = 0; i < distrib.length; ++i) {
                distrib[i] += probs[i] * normalizedWeight;
            }
        }
        return new Distribution(distrib);
    }

    /**
     * Converts an annotation's per-consequent rule quality into a distribution of confidences.
     *
     * <p>On the first call, the annotation's consequents define the cached index order. A
     * consequent missing from a later annotation is logged as SEVERE and given confidence 0.
     *
     * @param annot annotation whose distribution maps consequents to rule quality
     * @return distribution of confidences indexed like {@link #consequents}
     */
    public Distribution convert(AttributeValueAnnotation annot) {
        if (this.consequents == null) {
            if (LOGGER.isLoggable(Level.FINE)) {
                LOGGER.fine("Distribution class init: caching consequents order and count");
            }
            this.init(annot.getDistribution().keySet());
        }
        float[] distrib = new float[this.consequents.size()];
        for (int i = 0; i < this.consequents.size(); ++i) {
            Consequent con = this.consequents.get(i);
            RuleQuality rq = annot.getDistribution().get(con);
            if (rq == null) {
                LOGGER.log(Level.SEVERE, "No distribution associated with value:{0} and consequent{1}", new Object[]{annot.getValue(), con});
                // Previously fell through to rq.getConfidence() and threw a NullPointerException;
                // treat the missing consequent as having no support instead.
                distrib[i] = 0.0f;
                continue;
            }
            distrib[i] = rq.getConfidence();
        }
        return new Distribution(distrib);
    }
}

