/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.cost;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.sysds.api.DMLException;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.cost.FederatedCost;
import org.apache.sysds.hops.cost.FederatedCostEstimator;
import org.apache.sysds.hops.fedplanner.FTypes;
import org.apache.sysds.hops.fedplanner.MemoTable;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;

public class HopRel {
    protected final Hop hopRef;
    protected final FEDInstruction.FederatedOutput fedOut;
    protected Types.ExecType execType;
    protected FTypes.FType fType;
    protected FederatedCost cost;
    protected final Set<Long> costPointerSet = new HashSet<Long>();
    protected List<Hop> inputHops;
    protected List<HopRel> inputDependency = new ArrayList<HopRel>();

    public HopRel(Hop associatedHop, FEDInstruction.FederatedOutput fedOut, MemoTable hopRelMemo) {
        this(associatedHop, fedOut, null, hopRelMemo, associatedHop.getInput());
    }

    public HopRel(Hop associatedHop, FEDInstruction.FederatedOutput fedOut, MemoTable hopRelMemo, ArrayList<Hop> inputs) {
        this(associatedHop, fedOut, null, hopRelMemo, inputs);
    }

    private HopRel(Hop associatedHop, FEDInstruction.FederatedOutput fedOut, FTypes.FType fType, List<Hop> inputs) {
        this.hopRef = associatedHop;
        this.fedOut = fedOut;
        this.fType = fType;
        this.inputHops = inputs;
    }

    public HopRel(Hop associatedHop, FEDInstruction.FederatedOutput fedOut, FTypes.FType fType, MemoTable hopRelMemo, ArrayList<Hop> inputs) {
        this(associatedHop, fedOut, fType, inputs);
        this.setInputDependency(hopRelMemo);
        this.cost = FederatedCostEstimator.costEstimate(this, hopRelMemo);
        this.setExecType();
    }

    public HopRel(Hop associatedHop, FEDInstruction.FederatedOutput fedOut, FTypes.FType fType, MemoTable hopRelMemo, List<Hop> inputs, List<FTypes.FType> inputDependency) {
        this(associatedHop, fedOut, fType, inputs);
        this.setInputFTypeDependency(inputs, inputDependency, hopRelMemo);
        this.cost = FederatedCostEstimator.costEstimate(this, hopRelMemo);
        this.setExecType();
    }

    private void setInputFTypeDependency(List<Hop> inputs, List<FTypes.FType> inputDependency, MemoTable hopRelMemo) {
        for (int i = 0; i < inputs.size(); ++i) {
            this.inputDependency.add(hopRelMemo.getHopRel(inputs.get(i), inputDependency.get(i)));
        }
        this.validateInputDependency();
    }

    private void setExecType() {
        if (this.inputDependency.stream().anyMatch(HopRel::hasFederatedOutput) || HopRewriteUtils.isData(this.hopRef, Types.OpOpData.FEDERATED)) {
            this.execType = Types.ExecType.FED;
        }
    }

    public void addCostPointer(long hopID) {
        this.costPointerSet.add(hopID);
    }

    public boolean existingCostPointer(long currentHopID) {
        if (this.costPointerSet.contains(currentHopID)) {
            return this.costPointerSet.size() > 1;
        }
        return this.costPointerSet.size() > 0;
    }

    public boolean hasLocalOutput() {
        return this.fedOut == FEDInstruction.FederatedOutput.LOUT;
    }

    public boolean hasFederatedOutput() {
        return this.fedOut == FEDInstruction.FederatedOutput.FOUT;
    }

    public FEDInstruction.FederatedOutput getFederatedOutput() {
        return this.fedOut;
    }

    public List<HopRel> getInputDependency() {
        return this.inputDependency;
    }

    public Hop getHopRef() {
        return this.hopRef;
    }

    public FTypes.FType getFType() {
        return this.fType;
    }

    public void setFType(FTypes.FType fType) {
        this.fType = fType;
    }

    public Types.ExecType getExecType() {
        return this.execType;
    }

    private HopRel getFOUTHopRel(Hop hop, MemoTable hopRelMemo) {
        return hopRelMemo.getFederatedOutputAlternativeOrNull(hop);
    }

    private void setInputDependency(MemoTable hopRelMemo) {
        if (this.inputHops != null && this.inputHops.size() > 0) {
            if (this.fedOut == FEDInstruction.FederatedOutput.FOUT && !this.hopRef.isFederatedDataOp()) {
                int lowestFOUTIndex = 0;
                HopRel lowestFOUTHopRel = this.getFOUTHopRel(this.inputHops.get(0), hopRelMemo);
                for (int i = 1; i < this.inputHops.size(); ++i) {
                    Hop input = this.inputHops.get(i);
                    HopRel foutHopRel = this.getFOUTHopRel(input, hopRelMemo);
                    if (lowestFOUTHopRel == null) {
                        lowestFOUTHopRel = foutHopRel;
                        lowestFOUTIndex = i;
                        continue;
                    }
                    if (foutHopRel == null || !(foutHopRel.getCost() < lowestFOUTHopRel.getCost())) continue;
                    lowestFOUTHopRel = foutHopRel;
                    lowestFOUTIndex = i;
                }
                HopRel[] inputHopRels = new HopRel[this.inputHops.size()];
                for (int i = 0; i < this.inputHops.size(); ++i) {
                    if (i != lowestFOUTIndex) {
                        Hop input = this.inputHops.get(i);
                        inputHopRels[i] = hopRelMemo.getMinCostAlternative(input);
                        continue;
                    }
                    inputHopRels[i] = lowestFOUTHopRel;
                }
                this.inputDependency.addAll(Arrays.asList(inputHopRels));
            } else {
                this.inputDependency.addAll(this.inputHops.stream().map(hopRelMemo::getMinCostAlternative).collect(Collectors.toList()));
            }
        }
        this.validateInputDependency();
    }

    private void validateInputDependency() {
        for (int i = 0; i < this.inputDependency.size(); ++i) {
            if (this.inputDependency.get(i) != null) continue;
            throw new DMLException("HopRel input number " + i + " (" + this.hopRef.getInput(i) + ") is null for root: \n" + this);
        }
    }

    public double getCost() {
        return this.cost.getTotal();
    }

    public FederatedCost getCostObject() {
        return this.cost;
    }

    public String toString() {
        StringBuilder strB = new StringBuilder();
        strB.append(this.getClass().getSimpleName());
        strB.append(" {HopID: ");
        strB.append(this.hopRef.getHopID());
        strB.append(", Opcode: ");
        strB.append(this.hopRef.getOpString());
        strB.append(", FedOut: ");
        strB.append((Object)this.fedOut);
        strB.append(", Cost: ");
        strB.append(this.cost.getTotal());
        strB.append(", Inputs: ");
        strB.append(this.inputDependency.stream().map(i -> "{" + i.getHopRef().getHopID() + ", " + i.getFederatedOutput() + "}").collect(Collectors.toList()));
        strB.append("}");
        return strB.toString();
    }
}

