/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.colgroup;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.PriorityQueue;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.sysds.runtime.DMLCompressionException;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.BitmapEncoder;
import org.apache.sysds.runtime.compress.CompressionSettings;
import org.apache.sysds.runtime.compress.colgroup.ColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC1;
import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC2;
import org.apache.sysds.runtime.compress.colgroup.ColGroupOLE;
import org.apache.sysds.runtime.compress.colgroup.ColGroupRLE;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
import org.apache.sysds.runtime.compress.estim.CompressedSizeEstimatorExact;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
import org.apache.sysds.runtime.compress.utils.ABitmap;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.CommonThreadPool;

public class ColGroupFactory {

    /**
     * Compresses every column group in {@code groups}, using up to {@code k} threads.
     *
     * <p>With {@code k <= 1} the groups are compressed sequentially on the calling thread. If the
     * parallel execution is interrupted or a task fails, the whole batch is recomputed sequentially
     * (so a genuine failure in a task resurfaces from the sequential run).
     *
     * @param in           the raw input matrix
     * @param compRatios   per-column compression ratios, keyed by column index
     * @param groups       column index groups to compress
     * @param compSettings compression settings
     * @param k            degree of parallelism
     * @return one entry per group, in the order of {@code groups}; an entry is {@code null} when no
     *         subset of that group was worth compressing
     */
    public static ColGroup[] compressColGroups(MatrixBlock in, HashMap<Integer, Double> compRatios, List<int[]> groups, CompressionSettings compSettings, int k) {
        if (k <= 1) {
            return ColGroupFactory.compressColGroups(in, compRatios, groups, compSettings);
        }
        ExecutorService pool = CommonThreadPool.get(k);
        try {
            List<CompressTask> tasks = new ArrayList<>(groups.size());
            for (int[] colIndexes : groups) {
                tasks.add(new CompressTask(in, compRatios, colIndexes, compSettings));
            }
            List<Future<ColGroup>> futures = pool.invokeAll(tasks);
            ColGroup[] ret = new ColGroup[futures.size()];
            for (int i = 0; i < ret.length; i++) {
                ret[i] = futures.get(i).get();
            }
            return ret;
        }
        catch (InterruptedException e) {
            // Restore the interrupt flag so callers further up can still observe it.
            Thread.currentThread().interrupt();
            return ColGroupFactory.compressColGroups(in, compRatios, groups, compSettings);
        }
        catch (ExecutionException e) {
            return ColGroupFactory.compressColGroups(in, compRatios, groups, compSettings);
        }
        finally {
            // Always release the pool, including on the failure paths.
            pool.shutdown();
        }
    }

    /**
     * Sequentially compresses every column group in {@code groups}.
     *
     * @return one entry per group, in order; {@code null} entries mark groups that did not compress
     */
    private static ColGroup[] compressColGroups(MatrixBlock in, HashMap<Integer, Double> compRatios, List<int[]> groups, CompressionSettings compSettings) {
        ColGroup[] ret = new ColGroup[groups.size()];
        for (int i = 0; i < groups.size(); ++i) {
            ret[i] = ColGroupFactory.compressColGroup(in, compRatios, groups.get(i), compSettings);
        }
        return ret;
    }

    /**
     * Compresses a single column group, shrinking it until compression pays off.
     *
     * <p>Each round extracts a bitmap for the current columns and estimates its exact compressed
     * size. If the ratio {@code uncompressedSize / minSize} exceeds 1, the group is compressed with
     * the best-estimated type. Otherwise the remaining column with the lowest value in
     * {@code compRatios} is dropped and the estimate is repeated.
     *
     * @return the compressed group, or {@code null} if no non-empty subset of the columns compresses
     * @throws DMLRuntimeException if the estimated minimum size is 0
     */
    private static ColGroup compressColGroup(MatrixBlock in, HashMap<Integer, Double> compRatios, int[] colIndexes, CompressionSettings compSettings) {
        PriorityQueue<CompressedColumn> compRatioPQ = CompressedColumn.makePriorityQue(compRatios, colIndexes);
        CompressedSizeEstimatorExact estimator = new CompressedSizeEstimatorExact(in, compSettings);
        // Loop-invariant: number of rows of the (possibly transposed) input.
        int rlen = compSettings.transposeInput ? in.getNumColumns() : in.getNumRows();
        // removed[i] marks colIndexes[i] as dropped from the candidate group.
        boolean[] removed = new boolean[colIndexes.length];
        int[] cols = colIndexes;
        while (true) {
            ABitmap ubm = BitmapEncoder.extractBitmap(cols, in, compSettings);
            CompressedSizeInfoColGroup sizeInfo = new CompressedSizeInfoColGroup(estimator.estimateCompressedColGroupSize(ubm), compSettings.validCompressions);
            if (sizeInfo.getMinSize() == 0L) {
                throw new DMLRuntimeException("Size info of compressed Col Group is 0");
            }
            // Cast before dividing so a ratio such as 1.5 is not truncated to 1 by integer division.
            double compRatio = (double) sizeInfo.getCompressionSize(ColGroup.CompressionType.UNCOMPRESSED) / sizeInfo.getMinSize();
            if (compRatio > 1.0) {
                return ColGroupFactory.compress(cols, rlen, ubm, sizeInfo.getBestCompressionType(), compSettings, in);
            }
            CompressedColumn worst = compRatioPQ.poll();
            if (worst == null || cols.length <= 1) {
                // Nothing left to drop: no subset of this group compresses.
                return null;
            }
            removed[worst.colIx] = true;
            cols = ColGroupFactory.remainingColumns(colIndexes, removed, cols.length - 1);
        }
    }

    /** Collects the {@code count} entries of {@code allCols} not flagged in {@code removed}, in order. */
    private static int[] remainingColumns(int[] allCols, boolean[] removed, int count) {
        int[] ret = new int[count];
        int ix = 0;
        for (int i = 0; i < allCols.length; i++) {
            if (!removed[i]) {
                ret[ix++] = allCols[i];
            }
        }
        return ret;
    }

    /**
     * Instantiates the column group for the given compression type.
     *
     * <p>For DDC, {@link ColGroupDDC1} is used when the bitmap has fewer than 256 distinct values,
     * otherwise {@link ColGroupDDC2}.
     *
     * @throws DMLCompressionException if {@code compType} is not handled here
     */
    public static ColGroup compress(int[] colIndexes, int rlen, ABitmap ubm, ColGroup.CompressionType compType, CompressionSettings cs, MatrixBlock rawMatrixBlock) {
        switch (compType) {
            case DDC: {
                if (ubm.getNumValues() < 256) {
                    return new ColGroupDDC1(colIndexes, rlen, ubm, cs);
                }
                return new ColGroupDDC2(colIndexes, rlen, ubm, cs);
            }
            case RLE: {
                return new ColGroupRLE(colIndexes, rlen, ubm, cs);
            }
            case OLE: {
                return new ColGroupOLE(colIndexes, rlen, ubm, cs);
            }
            case UNCOMPRESSED: {
                return new ColGroupUncompressed(colIndexes, rawMatrixBlock, cs);
            }
        }
        throw new DMLCompressionException("Not implemented ColGroup Type compressed in factory.");
    }

    /**
     * Gathers the non-null compressed groups and puts every column they do not cover into one
     * trailing {@link ColGroupUncompressed}.
     *
     * <p>The uncovered columns are passed in ascending order. Column indices outside
     * {@code [0, numCols)} are ignored.
     *
     * @param numCols      total number of columns
     * @param colGroups    compressed groups, possibly containing {@code null} entries
     * @param rawBlock     raw input used for the uncompressed remainder
     * @param compSettings compression settings
     * @return the non-null groups in input order, plus the uncompressed remainder if any
     */
    public static List<ColGroup> assignColumns(int numCols, ColGroup[] colGroups, MatrixBlock rawBlock, CompressionSettings compSettings) {
        List<ColGroup> ret = new ArrayList<>(colGroups.length + 1);
        boolean[] assigned = new boolean[Math.max(numCols, 0)];
        int numAssigned = 0;
        for (ColGroup grp : colGroups) {
            if (grp == null) {
                continue;
            }
            for (int col : grp.getColIndices()) {
                if (col >= 0 && col < numCols && !assigned[col]) {
                    assigned[col] = true;
                    numAssigned++;
                }
            }
            ret.add(grp);
        }
        if (numAssigned < numCols) {
            int[] remaining = new int[numCols - numAssigned];
            int ix = 0;
            for (int col = 0; col < numCols; col++) {
                if (!assigned[col]) {
                    remaining[ix++] = col;
                }
            }
            ret.add(new ColGroupUncompressed(remaining, rawBlock, compSettings));
        }
        return ret;
    }

    /** Task that compresses a single column group; used by the parallel {@code compressColGroups}. */
    private static class CompressTask
    implements Callable<ColGroup> {
        private final MatrixBlock _in;
        private final HashMap<Integer, Double> _compRatios;
        private final int[] _colIndexes;
        private final CompressionSettings _compSettings;

        protected CompressTask(MatrixBlock in, HashMap<Integer, Double> compRatios, int[] colIndexes, CompressionSettings compSettings) {
            this._in = in;
            this._compRatios = compRatios;
            this._colIndexes = colIndexes;
            this._compSettings = compSettings;
        }

        @Override
        public ColGroup call() {
            return ColGroupFactory.compressColGroup(this._in, this._compRatios, this._colIndexes, this._compSettings);
        }
    }

    /**
     * A column's position within its group together with its compression ratio.
     * Ordered by ascending compression ratio, so a {@link PriorityQueue} yields the worst column first.
     */
    private static class CompressedColumn
    implements Comparable<CompressedColumn> {
        /** Position of the column within the group's index array (not the matrix column index). */
        final int colIx;
        final double compRatio;

        public CompressedColumn(int colIx, double compRatio) {
            this.colIx = colIx;
            this.compRatio = compRatio;
        }

        /**
         * Builds a min-queue over the columns of a group, keyed by each column's ratio in
         * {@code compRatios}.
         *
         * @throws DMLCompressionException if a column has no entry in {@code compRatios}
         */
        public static PriorityQueue<CompressedColumn> makePriorityQue(HashMap<Integer, Double> compRatios, int[] colIndexes) {
            PriorityQueue<CompressedColumn> compRatioPQ = new PriorityQueue<CompressedColumn>();
            for (int i = 0; i < colIndexes.length; ++i) {
                Double ratio = compRatios.get(colIndexes[i]);
                if (ratio == null) {
                    throw new DMLCompressionException("Missing compression ratio for column " + colIndexes[i]);
                }
                compRatioPQ.add(new CompressedColumn(i, ratio));
            }
            return compRatioPQ;
        }

        @Override
        public int compareTo(CompressedColumn o) {
            // Double.compare gives a total order (incl. NaN), which PriorityQueue requires.
            return Double.compare(this.compRatio, o.compRatio);
        }
    }
}

