/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.cost;

import java.util.ArrayList;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.cost.ComputeCost;
import org.apache.sysds.hops.cost.FederatedCost;
import org.apache.sysds.hops.cost.HopRel;
import org.apache.sysds.hops.fedplanner.MemoTable;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.FunctionStatement;
import org.apache.sysds.parser.FunctionStatementBlock;
import org.apache.sysds.parser.IfStatement;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;

/**
 * Estimates the execution cost of DML programs, statement blocks and HOP DAGs under federated
 * execution.
 * <p>
 * The cost of a single HOP (see {@link #costEstimate(Hop)}) is assembled from a read time, an input
 * and an output network-transfer time and a compute time, all derived from the worker
 * characteristics in the public static fields below, together with the accumulated cost of the
 * HOP's not-yet-costed inputs and its repetition estimate. How these parts are combined into a
 * total is defined by {@link FederatedCost}.
 */
public class FederatedCostEstimator {
    private static final Log LOG = LogFactory.getLog(FederatedCostEstimator.class.getName());

    /**
     * Fallback value passed to the HOP memory-estimate getters ({@code getInputMemEstimate} /
     * {@code getOutputMemEstimate}); presumably used when no estimate is known - TODO confirm in
     * {@link Hop}.
     */
    public static int DEFAULT_MEMORY_ESTIMATE = 8;
    /** Network bandwidth of a worker in bytes per second (1 GiB/s). */
    public static double WORKER_NETWORK_BANDWIDTH_BYTES_PS = 1024d * 1024 * 1024;
    /**
     * Compute throughput in FLOP/s per unit of {@link #WORKER_DEGREE_OF_PARALLELISM} (2.5 * 2^30);
     * the compute time divides by the product of both.
     */
    public static double WORKER_COMPUTE_BANDWIDTH_FLOPS = 2.5 * 1024 * 1024 * 1024;
    /** Degree of parallelism of a worker's computation. */
    public static double WORKER_DEGREE_OF_PARALLELISM = 8.0;
    /** Read bandwidth of a worker in bytes per second (3.5 GiB/s). */
    public static double WORKER_READ_BANDWIDTH_BYTES_PS = 3.5 * 1024 * 1024 * 1024;

    /**
     * Estimates the total federated cost of a DML program.
     * <p>
     * Calls {@code updateRepetitionEstimates()} on the program first, then adds up the total cost
     * of each of its top-level statement blocks.
     *
     * @param dmlProgram the program to estimate
     * @return a cost object accumulating the totals of all top-level statement blocks
     */
    public FederatedCost costEstimate(DMLProgram dmlProgram) {
        dmlProgram.updateRepetitionEstimates();
        FederatedCost programTotalCost = new FederatedCost();
        for (StatementBlock stmBlock : dmlProgram.getStatementBlocks()) {
            programTotalCost.addInputTotalCost(costEstimate(stmBlock).getTotal());
        }
        return programTotalCost;
    }

    /**
     * Estimates the cost of a statement block by dispatching on its type; blocks without nested
     * structure are costed through the roots of their HOP DAG.
     */
    private FederatedCost costEstimate(StatementBlock sb) {
        if (sb instanceof WhileStatementBlock) {
            return costEstimateWhile((WhileStatementBlock) sb);
        }
        if (sb instanceof IfStatementBlock) {
            return costEstimateIf((IfStatementBlock) sb);
        }
        if (sb instanceof ForStatementBlock) {
            return costEstimateFor((ForStatementBlock) sb);
        }
        if (sb instanceof FunctionStatementBlock) {
            return costEstimateFunction((FunctionStatementBlock) sb);
        }
        return costEstimate(sb.getHops());
    }

    /** Costs the loop predicate plus every statement block of the while body. */
    private FederatedCost costEstimateWhile(WhileStatementBlock whileSB) {
        // Accumulate into a fresh object (as the if/for cases do): costEstimate(Hop) returns the
        // FederatedCost cached on the predicate HOP, so adding the body to that object would also
        // change the cost recorded on the predicate HOP itself.
        FederatedCost whileSBCost = new FederatedCost();
        whileSBCost.addInputTotalCost(costEstimate(whileSB.getPredicateHops()));
        for (Statement statement : whileSB.getStatements()) {
            for (StatementBlock bodyBlock : ((WhileStatement) statement).getBody()) {
                whileSBCost.addInputTotalCost(costEstimate(bodyBlock));
            }
        }
        return whileSBCost;
    }

    /**
     * Costs every if statement of the block by adding up both the if body and the else body (the
     * branches are summed, not weighted), followed by the predicate.
     */
    private FederatedCost costEstimateIf(IfStatementBlock ifSB) {
        FederatedCost ifSBCost = new FederatedCost();
        for (Statement statement : ifSB.getStatements()) {
            IfStatement ifStatement = (IfStatement) statement;
            for (StatementBlock ifBodySB : ifStatement.getIfBody()) {
                ifSBCost.addInputTotalCost(costEstimate(ifBodySB));
            }
            for (StatementBlock elseBodySB : ifStatement.getElseBody()) {
                ifSBCost.addInputTotalCost(costEstimate(elseBodySB));
            }
        }
        ifSBCost.addInputTotalCost(costEstimate(ifSB.getPredicateHops()));
        return ifSBCost;
    }

    /** Costs the from/to/increment HOPs plus every statement block of the for body. */
    private FederatedCost costEstimateFor(ForStatementBlock forSB) {
        ArrayList<Hop> predicateHops = new ArrayList<>();
        predicateHops.add(forSB.getFromHops());
        predicateHops.add(forSB.getToHops());
        // The increment HOP may be null (presumably when the loop has no explicit increment);
        // null roots are skipped when the list is costed.
        predicateHops.add(forSB.getIncrementHops());
        FederatedCost forSBCost = costEstimate(predicateHops);
        for (Statement statement : forSB.getStatements()) {
            for (StatementBlock bodyBlock : ((ForStatement) statement).getBody()) {
                forSBCost.addInputTotalCost(costEstimate(bodyBlock));
            }
        }
        return forSBCost;
    }

    /** Costs the function body on top of the program-level cost from {@link #addInitialInputCost}. */
    private FederatedCost costEstimateFunction(FunctionStatementBlock funcSB) {
        FederatedCost funcCost = addInitialInputCost(funcSB);
        for (Statement statement : funcSB.getStatements()) {
            for (StatementBlock bodyBlock : ((FunctionStatement) statement).getBody()) {
                funcCost.addInputTotalCost(costEstimate(bodyBlock));
            }
        }
        return funcCost;
    }

    /**
     * Sums the total cost of every top-level statement block of the program that {@code sb}
     * belongs to.
     * <p>
     * NOTE(review): this re-costs the whole program for each function block; HOPs that were
     * already costed return their cached cost again, so those totals may be counted more than
     * once - verify this is intended.
     */
    private FederatedCost addInitialInputCost(StatementBlock sb) {
        FederatedCost basicCost = new FederatedCost();
        for (StatementBlock childSB : sb.getDMLProg().getStatementBlocks()) {
            basicCost.addInputTotalCost(costEstimate(childSB).getTotal());
        }
        return basicCost;
    }

    /**
     * Adds up the costs of the given HOP DAG roots. A {@code null} list or {@code null} entries
     * contribute nothing.
     */
    private FederatedCost costEstimate(ArrayList<Hop> roots) {
        FederatedCost basicCost = new FederatedCost();
        if (roots == null) {
            return basicCost;
        }
        for (Hop root : roots) {
            if (root != null) {
                basicCost.addInputTotalCost(costEstimate(root));
            }
        }
        return basicCost;
    }

    /**
     * Estimates the federated cost of the HOP DAG rooted at {@code root} and caches it on the HOP.
     * <p>
     * If the root already carries a cost, that cached object is returned unchanged. Inputs that
     * already carry a cost contribute nothing to the input costs, so a sub-DAG shared by several
     * consumers is only counted for the first consumer that reaches it.
     *
     * @param root root of the HOP DAG to estimate
     * @return the cost object stored on {@code root}
     */
    public FederatedCost costEstimate(Hop root) {
        if (root.federatedCostInitialized()) {
            return root.getFederatedCost();
        }
        boolean hasFederatedInput = root.someInputFederated();
        double inputCosts = root.getInput().stream()
            .mapToDouble(in -> in.federatedCostInitialized() ? 0.0 : costEstimate(in).getTotal())
            .sum();
        double inputTransferCost = inputTransferCostEstimate(hasFederatedInput, root);
        double computeFlops = ComputeCost.getHOPComputeCost(root);
        long numFederatedInputs = hasFederatedInput
            ? root.getInput().stream().filter(Hop::hasFederatedOutput).count()
            : 0L;
        double computingCost = computeTimeEstimate(computeFlops, numFederatedInputs);
        double outputTransferCost =
            outputTransferCostEstimate(root.hasLocalOutput(), hasFederatedInput, root);
        double readCost = readCostEstimate(root);
        double rootRepetitions = root.getRepetitions();
        FederatedCost rootFedCost = new FederatedCost(readCost, inputTransferCost,
            outputTransferCost, computingCost, inputCosts, rootRepetitions);
        root.setFederatedCost(rootFedCost);
        if (LOG.isDebugEnabled()) {
            LOG.debug(getCostInfo(root));
        }
        return rootFedCost;
    }

    /**
     * Estimates the federated cost of a {@link HopRel}, i.e. a HOP together with a chosen
     * federated-output configuration and its input dependencies.
     * <p>
     * If the memo table already contains {@code root}, its existing cost object is returned. The
     * result is not inserted into the memo table here.
     *
     * @param root the HOP relation to estimate
     * @param hopRelMemo memo table of already-costed HOP relations
     * @return a new cost object, or the existing one for memoized relations
     */
    public static FederatedCost costEstimate(HopRel root, MemoTable hopRelMemo) {
        if (hopRelMemo.containsHopRel(root)) {
            return root.getCostObject();
        }
        boolean hasFederatedInput = root.inputDependency.stream().anyMatch(HopRel::hasFederatedOutput);
        double inputCosts = root.inputDependency.stream()
            .mapToDouble(in -> chargeInputCost(in, root, hopRelMemo))
            .sum();
        double inputTransferCost = inputTransferCostEstimate(hasFederatedInput, root);
        double computeFlops = ComputeCost.getHOPComputeCost(root.hopRef);
        long numFederatedInputs = hasFederatedInput
            ? root.inputDependency.stream().filter(HopRel::hasFederatedOutput).count()
            : 0L;
        double computingCost = computeTimeEstimate(computeFlops, numFederatedInputs);
        double outputTransferCost =
            outputTransferCostEstimate(root.hasLocalOutput(), hasFederatedInput, root.hopRef);
        double readCost = readCostEstimate(root.hopRef);
        double rootRepetitions = root.hopRef.getRepetitions();
        return new FederatedCost(readCost, inputTransferCost, outputTransferCost, computingCost,
            inputCosts, rootRepetitions);
    }

    /**
     * Returns the total cost of input {@code in} for {@code consumer}, or 0 if {@code in} already
     * holds a cost pointer to the consumer's HOP id, and then records such a pointer.
     * <p>
     * Side effect: always calls {@code addCostPointer} on {@code in}; callers must process inputs
     * sequentially.
     */
    private static double chargeInputCost(HopRel in, HopRel consumer, MemoTable hopRelMemo) {
        double inCost = in.existingCostPointer(consumer.hopRef.getHopID())
            ? 0.0
            : costEstimate(in, hopRelMemo).getTotal();
        in.addCostPointer(consumer.hopRef.getHopID());
        return inCost;
    }

    /**
     * Converts a compute cost into an estimated compute time.
     *
     * @param flops compute cost of the HOP as reported by {@link ComputeCost}
     * @param numFederatedInputs number of inputs with federated output, or 0 for local execution;
     *        the work is divided evenly across that many workers
     * @return {@code flops} divided by the aggregate throughput of the participating workers
     */
    private static double computeTimeEstimate(double flops, long numFederatedInputs) {
        // Never divide by zero workers: a HOP reported as having federated input without any
        // federated-output input would otherwise get an Infinity/NaN compute cost.
        double workers = Math.max(1L, numFederatedInputs);
        return flops / (workers * WORKER_DEGREE_OF_PARALLELISM * WORKER_COMPUTE_BANDWIDTH_FLOPS);
    }

    /**
     * Time to send the HOP's output over the network; non-zero only when the output is local and
     * the HOP either has federated input or is a federated data operation.
     */
    private static double outputTransferCostEstimate(boolean hasLocalOutput, boolean hasFederatedInput,
        Hop hop) {
        if (hasLocalOutput && (hasFederatedInput || hop.isFederatedDataOp())) {
            return hop.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE) / WORKER_NETWORK_BANDWIDTH_BYTES_PS;
        }
        return 0.0;
    }

    /** Time to read the HOP's inputs at the worker read bandwidth. */
    private static double readCostEstimate(Hop hop) {
        return hop.getInputMemEstimate(DEFAULT_MEMORY_ESTIMATE) / WORKER_READ_BANDWIDTH_BYTES_PS;
    }

    /**
     * Network time for the inputs of a HOP relation with federated input: federated data
     * operations transfer their federated-output inputs, all other HOPs their local-output inputs.
     * Returns 0 when there is no federated input.
     */
    private static double inputTransferCostEstimate(boolean hasFederatedInput, HopRel root) {
        if (!hasFederatedInput) {
            return 0.0;
        }
        boolean federatedDataOp = root.hopRef.isFederatedDataOp();
        return root.inputDependency.stream()
            .filter(input -> federatedDataOp ? input.hasFederatedOutput() : input.hasLocalOutput())
            .mapToDouble(in -> in.hopRef.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE))
            .sum() / WORKER_NETWORK_BANDWIDTH_BYTES_PS;
    }

    /**
     * Network time for the inputs of a HOP with federated input: federated data operations
     * transfer their federated-output inputs, all other HOPs their local-output inputs.
     * Returns 0 when there is no federated input.
     */
    private static double inputTransferCostEstimate(boolean hasFederatedInput, Hop root) {
        if (!hasFederatedInput) {
            return 0.0;
        }
        boolean federatedDataOp = root.isFederatedDataOp();
        return root.getInput().stream()
            .filter(input -> federatedDataOp ? input.hasFederatedOutput() : input.hasLocalOutput())
            .mapToDouble(in -> in.getOutputMemEstimate(DEFAULT_MEMORY_ESTIMATE))
            .sum() / WORKER_NETWORK_BANDWIDTH_BYTES_PS;
    }

    /** Builds a multi-line, human-readable summary of {@code root} and its cached cost for debug logging. */
    private static String getCostInfo(Hop root) {
        String sep = System.lineSeparator();
        return new StringBuilder()
            .append(root).append(sep)
            .append("Is federated: ").append(root.isFederated())
            .append(" Has federated output: ").append(root.hasFederatedOutput())
            // keep the boolean flag and the HOP text apart in the log output
            .append(" Text: ").append(root.getText()).append(sep)
            .append("Pure computeCost: ").append(ComputeCost.getHOPComputeCost(root))
            .append(" Dim1: ").append(root.getDim1())
            .append(" Dim2: ").append(root.getDim2()).append(sep)
            .append(root.getFederatedCost()).append(sep)
            .toString();
    }
}

