/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.hops;

import java.util.ArrayList;
import java.util.List;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.MemoTable;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.lops.FunctionCallCP;
import org.apache.sysml.lops.FunctionCallCPSingle;
import org.apache.sysml.lops.Lop;
import org.apache.sysml.lops.LopProperties;
import org.apache.sysml.parser.DMLProgram;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.runtime.controlprogram.parfor.opt.CostEstimatorHops;

/**
 * Hop for a function call: a DML-bodied function, an external function, or a
 * multi-return builtin (e.g. qr, lu, eigen, svd, lstm, batch_norm2d).
 * <p>
 * The call's arguments are ordinary hop inputs. Multi-output calls keep an
 * explicit list of output hops ({@link #getOutputs()}); calls created with
 * {@code singleOut == true} are compiled to a {@link FunctionCallCPSingle}
 * lop instead of a {@link FunctionCallCP} lop.
 */
public class FunctionOp extends Hop {
    /** Operator string reported by {@link #getOpString()}. */
    public static final String OPSTRING = "extfunct";

    /** Namespace identifying builtin (internal) functions, see {@link #isBuiltinFunction()}. */
    private static final String INTERNAL_NAMESPACE = "_internal";

    /** Kind of function being called. */
    private FunctionType _type = null;
    /** Namespace of the called function. */
    private String _fnamespace = null;
    /** Name of the called function. */
    private String _fname = null;
    /** Names of the function's input parameters (may be null). */
    private String[] _inputNames = null;
    /** Names of the function's output variables. */
    private String[] _outputNames = null;
    /** Output hops of a multi-output call; null for single-output calls. */
    private ArrayList<Hop> _outputHops = null;
    /** True if lops are built as {@link FunctionCallCPSingle}. */
    private boolean _singleOutFun = false;

    /** Bare constructor used only by {@link #clone()}. */
    private FunctionOp() {
    }

    /**
     * Creates a (multi-output) function call with an explicit list of output hops.
     *
     * @param type        kind of function
     * @param fnamespace  function namespace
     * @param fname       function name
     * @param inputNames  names of the input parameters
     * @param inputs      argument hops; each becomes an input of this hop
     * @param outputNames names of the output variables
     * @param outputHops  hops representing the call's outputs
     */
    public FunctionOp(FunctionType type, String fnamespace, String fname, String[] inputNames, List<Hop> inputs, String[] outputNames, ArrayList<Hop> outputHops) {
        this(type, fnamespace, fname, inputNames, inputs, outputNames, false);
        this._outputHops = outputHops;
    }

    /**
     * Creates a function call without output hops.
     *
     * @param type        kind of function
     * @param fnamespace  function namespace
     * @param fname       function name
     * @param inputNames  names of the input parameters
     * @param inputs      argument hops; each becomes an input of this hop and
     *                    gets this hop registered as a parent
     * @param outputNames names of the output variables
     * @param singleOut   if true, lops are built as {@link FunctionCallCPSingle}
     */
    public FunctionOp(FunctionType type, String fnamespace, String fname, String[] inputNames, List<Hop> inputs, String[] outputNames, boolean singleOut) {
        super(fnamespace + "::" + fname, Expression.DataType.UNKNOWN, Expression.ValueType.UNKNOWN);
        this._type = type;
        this._fnamespace = fnamespace;
        this._fname = fname;
        this._inputNames = inputNames;
        this._outputNames = outputNames;
        this._singleOutFun = singleOut;
        // wire the argument hops into the DAG in both directions
        for (Hop in : inputs) {
            getInput().add(in);
            in.getParent().add(this);
        }
    }

    /** No arity constraint is enforced for function calls. */
    @Override
    public void checkArity() {
    }

    /** @return the program-wide key of the called function (namespace + name) */
    public String getFunctionKey() {
        return DMLProgram.constructFunctionKey(getFunctionNamespace(), getFunctionName());
    }

    public String getFunctionNamespace() {
        return _fnamespace;
    }

    public String getFunctionName() {
        return _fname;
    }

    public void setFunctionName(String fname) {
        _fname = fname;
    }

    /** @return the output hops of a multi-output call, or null if none were given */
    public ArrayList<Hop> getOutputs() {
        return _outputHops;
    }

    public String[] getInputVariableNames() {
        return _inputNames;
    }

    public String[] getOutputVariableNames() {
        return _outputNames;
    }

    public FunctionType getFunctionType() {
        return _type;
    }

    @Override
    public boolean allowsAllExecTypes() {
        return false;
    }

    /**
     * Computes the memory estimate of this call depending on the function type.
     * For multi-return builtins, the output and intermediate estimates are only
     * refined when all output hops have known dimensions.
     */
    @Override
    public void computeMemEstimate(MemoTable memo) {
        if (_type == FunctionType.DML) {
            // constant nominal estimate; NOTE(review): presumably the callee's body
            // is estimated separately - confirm
            _memEstimate = 1.0;
        } else if (_type == FunctionType.EXTERNAL_MEM) {
            _memEstimate = 2.0 * getInputSize();
        } else if (_type == FunctionType.EXTERNAL_FILE || _type == FunctionType.UNKNOWN) {
            _memEstimate = CostEstimatorHops.DEFAULT_MEM_MR;
        } else if (_type == FunctionType.MULTIRETURN_BUILTIN) {
            ArrayList<Hop> outputs = getOutputs();
            // a missing output list is treated like unknown output dimensions
            boolean outputDimsKnown = outputs != null;
            if (outputs != null) {
                for (Hop out : outputs) {
                    outputDimsKnown &= out.dimsKnown();
                }
            }
            if (outputDimsKnown) {
                long lnnz = _nnz >= 0L ? _nnz : _dim1 * _dim2;
                _outputMemEstimate = computeOutputMemEstimate(_dim1, _dim2, lnnz);
                _processingMemEstimate = computeIntermediateMemEstimate(_dim1, _dim2, lnnz);
            }
            _memEstimate = getInputOutputSize();
        }
    }

    /**
     * Estimates the combined size of all outputs of a multi-return builtin.
     * The {@code dim1/dim2/nnz} arguments are ignored; the dimensions of the
     * individual output hops are used instead.
     *
     * @throws RuntimeException if this is not a multi-return builtin or the
     *                          function name is not supported
     */
    @Override
    protected double computeOutputMemEstimate(long dim1, long dim2, long nnz) {
        if (getFunctionType() != FunctionType.MULTIRETURN_BUILTIN) {
            throw new RuntimeException("Invalid call of computeOutputMemEstimate in FunctionOp.");
        }
        if (hasName("qr")) {
            // H and R, both estimated at sparsity 0.5
            return outputSize(0, 0.5) + outputSize(1, 0.5);
        }
        if (hasName("lu")) {
            // outputs are [P, L, U]: P at sparsity 1/ncol (one non-zero per row),
            // L and U triangular at sparsity 0.5
            Hop p = getOutputs().get(0);
            long outputP = OptimizerUtils.estimateSizeExactSparsity(p.getDim1(), p.getDim2(), 1.0 / (double) p.getDim2());
            return outputSize(1, 0.5) + outputSize(2, 0.5) + outputP;
        }
        if (hasName("eigen")) {
            // dense eigenvector matrix plus a dense single-column vector of eigenvalues
            long outputValues = OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), 1L, 1.0);
            return outputSize(0, 1.0) + outputValues;
        }
        if (hasName("lstm", "lstm_backward")) {
            return 0.0;
        }
        if (hasName("batch_norm2d", "batch_norm2d_train")) {
            return denseOutputsSize(5);
        }
        if (hasName("batch_norm2d_test")) {
            return denseOutputsSize(1);
        }
        if (hasName("batch_norm2d_backward")) {
            return denseOutputsSize(3);
        }
        if (hasName("svd")) {
            return denseOutputsSize(3);
        }
        throw new RuntimeException("Invalid call of computeOutputMemEstimate in FunctionOp.");
    }

    /**
     * Estimates intermediate memory of a multi-return builtin based on its first input.
     *
     * @throws RuntimeException if this is not a multi-return builtin or the
     *                          function name is not supported
     */
    @Override
    protected double computeIntermediateMemEstimate(long dim1, long dim2, long nnz) {
        if (getFunctionType() != FunctionType.MULTIRETURN_BUILTIN) {
            throw new RuntimeException("Invalid call of computeIntermediateMemEstimate in FunctionOp.");
        }
        if (hasName("qr")) {
            Hop in = getInput().get(0);
            return OptimizerUtils.estimateSizeExactSparsity(in.getDim1(), in.getDim2(), 1.0);
        }
        if (hasName("lu")) {
            return OptimizerUtils.estimateSizeExactSparsity(getInput().get(0).getDim1(), 1L, 1.0);
        }
        if (hasName("eigen")) {
            Hop in = getInput().get(0);
            return OptimizerUtils.estimateSizeExactSparsity(in.getDim1(), in.getDim2(), 1.0)
                + 3L * OptimizerUtils.estimateSizeExactSparsity(in.getDim1(), 1L, 1.0);
        }
        if (hasName("batch_norm2d", "batch_norm2d_backward", "batch_norm2d_train", "batch_norm2d_test")) {
            return 0.0;
        }
        if (hasName("lstm", "lstm_backward")) {
            return 0.0;
        }
        if (hasName("svd")) {
            // a single dense row with as many columns as the input
            return OptimizerUtils.estimateSizeExactSparsity(1L, getInput().get(0).getDim2(), 1.0);
        }
        throw new RuntimeException("Invalid call of computeIntermediateMemEstimate in FunctionOp.");
    }

    @Override
    protected long[] inferOutputCharacteristics(MemoTable memo) {
        throw new RuntimeException("Invalid call of inferOutputCharacteristics in FunctionOp.");
    }

    /** @return true for the lstm and batch_norm2d function families */
    @Override
    public boolean isGPUEnabled() {
        return hasName("lstm", "lstm_backward", "batch_norm2d", "batch_norm2d_backward", "batch_norm2d_train", "batch_norm2d_test");
    }

    /**
     * Builds (once) the function-call lop over the lops of all inputs.
     *
     * @return the cached or newly constructed lop
     */
    @Override
    public Lop constructLops() {
        if (getLops() != null) {
            return getLops();
        }
        LopProperties.ExecType et = optFindExecType();
        ArrayList<Lop> inputLops = new ArrayList<Lop>();
        for (Hop in : getInput()) {
            inputLops.add(in.constructLops());
        }
        Lop fcall;
        if (_singleOutFun) {
            fcall = new FunctionCallCPSingle(inputLops, _fnamespace, _fname, et);
        } else {
            fcall = new FunctionCallCP(inputLops, _fnamespace, _fname, _inputNames, _outputNames, _outputHops, et);
        }
        setLineNumbers(fcall);
        setLops(fcall);
        return getLops();
    }

    @Override
    public String getOpString() {
        return OPSTRING;
    }

    /**
     * Selects the execution type. Only builtin multi-return functions can leave CP:
     * transformencode may run on Spark, lstm/batch_norm2d_train run on GPU, and
     * batch_norm2d/batch_norm2d_backward run on GPU when an accelerator is enabled.
     *
     * @throws RuntimeException for lstm/lstm_backward without an accelerator
     */
    @Override
    protected LopProperties.ExecType optFindExecType() {
        checkAndSetForcedPlatform();
        if (getFunctionType() != FunctionType.MULTIRETURN_BUILTIN) {
            _etype = LopProperties.ExecType.CP;
            return _etype;
        }
        boolean builtin = isBuiltinFunction();
        if (builtin && hasName("transformencode")) {
            // getMemEstimate() is only evaluated when Spark is not forced
            boolean useSpark = _etypeForced == LopProperties.ExecType.SPARK
                || (getMemEstimate() >= OptimizerUtils.getLocalMemBudget() && OptimizerUtils.isSparkExecutionMode());
            _etype = useSpark ? LopProperties.ExecType.SPARK : LopProperties.ExecType.CP;
        } else if (builtin && hasName("lstm", "lstm_backward")) {
            if (!DMLScript.USE_ACCELERATOR) {
                throw new RuntimeException("The function " + getFunctionName() + " is only supported on GPU.");
            }
            _etype = LopProperties.ExecType.GPU;
        } else if (builtin && hasName("batch_norm2d", "batch_norm2d_backward")) {
            _etype = DMLScript.USE_ACCELERATOR ? LopProperties.ExecType.GPU : LopProperties.ExecType.CP;
        } else if (builtin && hasName("batch_norm2d_train")) {
            _etype = LopProperties.ExecType.GPU;
        } else {
            _etype = LopProperties.ExecType.CP;
        }
        return _etype;
    }

    /** @return true if the function lives in the internal (builtin) namespace; null-safe */
    private boolean isBuiltinFunction() {
        return INTERNAL_NAMESPACE.equals(getFunctionNamespace());
    }

    /**
     * @return true if the function name equals any of {@code names}, ignoring case;
     *         false if the function name is null
     */
    private boolean hasName(String... names) {
        String fname = getFunctionName();
        for (String candidate : names) {
            if (candidate.equalsIgnoreCase(fname)) {
                return true;
            }
        }
        return false;
    }

    /** @return size estimate of output hop {@code idx} at the given sparsity */
    private long outputSize(int idx, double sparsity) {
        Hop out = getOutputs().get(idx);
        return OptimizerUtils.estimateSizeExactSparsity(out.getDim1(), out.getDim2(), sparsity);
    }

    /** @return sum of the dense (sparsity 1.0) size estimates of the first {@code count} output hops */
    private long denseOutputsSize(int count) {
        long total = 0L;
        for (int i = 0; i < count; i++) {
            total += outputSize(i, 1.0);
        }
        return total;
    }

    /** No-op: size information of function calls is not refreshed here. */
    @Override
    public void refreshSizeInformation() {
    }

    /**
     * Creates a copy of this hop. Name arrays and the output-hop list are copied
     * shallowly; the single-output flag is preserved so the clone builds the
     * same kind of lop.
     */
    @Override
    public Object clone() throws CloneNotSupportedException {
        FunctionOp ret = new FunctionOp();
        // copy generic Hop attributes
        ret.clone(this, false);
        // copy FunctionOp-specific attributes
        ret._type = _type;
        ret._fnamespace = _fnamespace;
        ret._fname = _fname;
        ret._inputNames = _inputNames != null ? _inputNames.clone() : null;
        ret._outputNames = _outputNames != null ? _outputNames.clone() : null;
        ret._outputHops = _outputHops != null ? new ArrayList<Hop>(_outputHops) : null;
        ret._singleOutFun = _singleOutFun;
        return ret;
    }

    /** Always reports inequality; function-call hops are never considered equal to another hop. */
    @Override
    public boolean compare(Hop that) {
        return false;
    }

    /** Kinds of functions a {@link FunctionOp} can call. */
    public static enum FunctionType {
        DML,
        EXTERNAL_MEM,
        EXTERNAL_FILE,
        MULTIRETURN_BUILTIN,
        UNKNOWN;
    }
}

