/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.ipa;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.stream.IntStream;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.ipa.FunctionCallGraph;
import org.apache.sysds.hops.ipa.FunctionCallSizeInfo;
import org.apache.sysds.hops.ipa.IPAPass;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.FunctionStatement;
import org.apache.sysds.parser.FunctionStatementBlock;

/**
 * Inter-procedural analysis pass that eliminates simple forwarding functions.
 * <p>
 * A function is a forwarding candidate if its body is a single statement block
 * whose only hop is a {@link FunctionOp} with literal or transient-read arguments.
 * If such a function has exactly one call site, that call site is redirected to
 * the inner function, the call graph is updated, and (unless the call graph
 * contains second-order calls) the forwarding function is removed from the program.
 */
public class IPAPassForwardFunctionCalls
extends IPAPass {
    /**
     * Always applicable; candidate selection happens per function in
     * {@link #rewriteProgram(DMLProgram, FunctionCallGraph, FunctionCallSizeInfo)}.
     */
    @Override
    public boolean isApplicable(FunctionCallGraph fgraph) {
        return true;
    }

    /**
     * Redirects the single call site of every forwarding function to the function it forwards to.
     *
     * @param prog       program whose function statement blocks are inspected and possibly removed
     * @param fgraph     function call graph; updated in place via {@code replaceFunctionCalls}
     * @param fcallSizes not used by this pass
     * @return always {@code false}
     */
    @Override
    public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes) {
        for (String fkey : fgraph.getReachableFunctions()) {
            FunctionStatementBlock fsb = prog.getFunctionStatementBlock(fkey);
            FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);

            // step 1: body is exactly one block holding exactly one call with simple arguments
            if (fstmt.getBody().size() != 1
                || !singleFunctionOp(fstmt.getBody().get(0).getHops())
                || !hasOnlySimpleArguments((FunctionOp) fstmt.getBody().get(0).getHops().get(0))) {
                continue;
            }
            if (LOG.isDebugEnabled()) {
                LOG.debug("IPA: Forward-function-call candidate L1: '" + fkey + "'");
            }

            // step 2: outputs are forwarded in order, and there is exactly one call site
            // (a null or empty call list would otherwise fail on get(0) below)
            FunctionOp call2 = (FunctionOp) fstmt.getBody().get(0).getHops().get(0);
            if (!hasConsistentOutputOrdering(fstmt, call2)) {
                continue;
            }
            List<FunctionOp> calls = fgraph.getFunctionCalls(fkey);
            if (calls == null || calls.size() != 1) {
                continue;
            }
            if (LOG.isDebugEnabled()) {
                LOG.debug("IPA: Forward-function-call candidate L2: '" + fkey + "'");
            }

            // step 3: argument names must allow a by-name rewiring of the outer call
            FunctionOp call1 = calls.get(0);
            if (!hasValidVariableNames(call1)
                || !hasValidVariableNames(call2)
                || !forwardsArgumentsByName(call2)
                || !isFirstSubsetOfSecond(call2.getInputVariableNames(), call1.getInputVariableNames())) {
                continue;
            }

            // step 4: point the outer call at the inner function and rewire its inputs
            call1.setFunctionName(call2.getFunctionName());
            call1.setFunctionNamespace(call2.getFunctionNamespace());
            reconcileFunctionInputsInPlace(call1, call2);

            // step 5: update the call graph; the forwarding block is only dropped
            // when no second-order calls exist (presumably they could still reach it)
            fgraph.replaceFunctionCalls(fkey, call2.getFunctionKey());
            if (!fgraph.containsSecondOrderCall()) {
                prog.removeFunctionStatementBlock(fkey);
            }
            if (LOG.isDebugEnabled()) {
                LOG.debug("IPA: Forward-function-call: replaced '" + fkey + "' with '" + call2.getFunctionKey() + "'");
            }
        }
        return false;
    }

    /**
     * @return true iff {@code hops} is non-null and holds exactly one hop, which is a {@link FunctionOp}
     */
    private static boolean singleFunctionOp(ArrayList<Hop> hops) {
        return hops != null && hops.size() == 1 && hops.get(0) instanceof FunctionOp;
    }

    /**
     * @return true iff every input of {@code fop} is a literal or a transient read
     */
    private static boolean hasOnlySimpleArguments(FunctionOp fop) {
        return fop.getInput().stream()
            .allMatch(h -> h instanceof LiteralOp || HopRewriteUtils.isData(h, Types.OpOpData.TRANSIENTREAD));
    }

    /**
     * Checks that the function's declared outputs and the inner call's output variables
     * have the same names position by position (over the shorter of the two lists).
     */
    private static boolean hasConsistentOutputOrdering(FunctionStatement fstmt, FunctionOp fop2) {
        String[] outNames = fop2.getOutputVariableNames();
        int len = Math.min(fstmt.getOutputParams().size(), outNames.length);
        return IntStream.range(0, len)
            .allMatch(i -> fstmt.getOutputParams().get(i).getName().equals(outNames[i]));
    }

    /**
     * @return true iff {@code fop} has one non-null input variable name per input hop;
     *         the length check protects the index-aligned accesses in the rewiring step
     */
    private static boolean hasValidVariableNames(FunctionOp fop) {
        String[] names = fop.getInputVariableNames();
        return names != null
            && names.length == fop.getInput().size()
            && Arrays.stream(names).allMatch(s -> s != null);
    }

    /**
     * Checks that every non-literal argument of {@code fop} is a read of the variable with
     * the same name as the parameter it binds (e.g. {@code f2(a=a, b=b)}).
     * {@link #reconcileFunctionInputsInPlace} matches arguments purely by parameter name,
     * so a permuting forward such as {@code f2(a=b, b=a)} would otherwise be rewired to
     * the wrong outer arguments. Requires {@link #hasValidVariableNames} to hold for {@code fop}.
     */
    private static boolean forwardsArgumentsByName(FunctionOp fop) {
        String[] names = fop.getInputVariableNames();
        List<Hop> inputs = fop.getInput();
        return IntStream.range(0, inputs.size())
            .allMatch(i -> inputs.get(i) instanceof LiteralOp || names[i].equals(inputs.get(i).getName()));
    }

    /**
     * @return true iff every string in {@code first} also occurs in {@code second}
     */
    private static boolean isFirstSubsetOfSecond(String[] first, String[] second) {
        HashSet<String> probe = new HashSet<>(Arrays.asList(second));
        return Arrays.stream(first).allMatch(probe::contains);
    }

    /**
     * Replaces the inputs of {@code call1} with those it shares (by variable name) with
     * {@code call2}. Order follows {@code call1}. Literals bound by {@code call2} replace the
     * corresponding outer argument; otherwise {@code call1}'s own argument is kept.
     */
    private static void reconcileFunctionInputsInPlace(FunctionOp call1, FunctionOp call2) {
        // index the inner call's arguments by parameter name
        String[] names2 = call2.getInputVariableNames();
        HashMap<String, Hop> probe = new HashMap<>();
        for (int i = 0; i < call2.getInput().size(); i++) {
            probe.put(names2[i], call2.getInput().get(i));
        }

        // keep the outer arguments the inner function accepts, in the outer call's order
        String[] names1 = call1.getInputVariableNames();
        ArrayList<String> varNames = new ArrayList<>();
        ArrayList<Hop> inputs = new ArrayList<>();
        for (int i = 0; i < call1.getInput().size(); i++) {
            Hop inner = probe.get(names1[i]);
            if (inner == null) {
                continue;
            }
            varNames.add(names1[i]);
            // a literal fixed inside the forwarding function overrides the caller's argument
            inputs.add(inner instanceof LiteralOp ? inner : call1.getInput().get(i));
        }

        // the snapshot above is taken before detaching the old children
        HopRewriteUtils.removeAllChildReferences(call1);
        call1.addAllInputs(inputs);
        call1.setInputVariableNames(varNames.toArray(new String[0]));
    }
}

