/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.ipa;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.HopsException;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.ipa.FunctionCallGraph;
import org.apache.sysds.hops.ipa.FunctionCallSizeInfo;
import org.apache.sysds.hops.ipa.IPAPass;
import org.apache.sysds.hops.recompile.Recompiler;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.parser.FunctionStatement;
import org.apache.sysds.parser.FunctionStatementBlock;
import org.apache.sysds.parser.StatementBlock;

/**
 * Inter-procedural analysis pass that inlines simple functions into their call sites.
 *
 * <p>A function is considered for inlining if its body consists of exactly one
 * last-level statement block without function calls, and it either has a single
 * call site or its body has at most {@link #INLINING_MAX_NUM_OPS} operators. Each
 * eligible call is replaced by a deep copy of the function's HOP DAG, in which
 * transient reads of the formal parameters are rewired to the caller's argument
 * hops and transient writes of the formal outputs are renamed to the caller's
 * output variables.
 */
public class IPAPassInlineFunctions extends IPAPass {
    /**
     * Maximum number of operators (data ops and literals excluded) a function body
     * may contain to be inlined when the function has more than one call site.
     * Functions with exactly one call site are inlined regardless of size.
     */
    private static final int INLINING_MAX_NUM_OPS = 10;

    @Override
    public boolean isApplicable(FunctionCallGraph fgraph) {
        // always run; eligibility is decided per function in rewriteProgram
        return true;
    }

    /**
     * Inlines all eligible calls of all reachable functions.
     *
     * <p>A function's entries are removed from the call graph only if every one of
     * its call sites was inlined. Call sites that cannot be inlined are left in place.
     *
     * @param prog the DML program whose caller statement blocks are rewritten in place
     * @param fgraph the function call graph; updated when a function is fully inlined
     * @param fcallSizes function call size information (not used by this pass)
     * @return true if at least one function was fully inlined, or had no recorded
     *         calls left in the call graph
     * @throws HopsException if a call passes a named argument the function does not declare
     */
    @Override
    public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes) {
        boolean ret = false;
        for (String fkey : fgraph.getReachableFunctions()) {
            FunctionStatementBlock fsb = prog.getFunctionStatementBlock(fkey);
            FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
            List<FunctionOp> fcalls = fgraph.getFunctionCalls(fkey);
            if (fcalls == null) {
                // no calls recorded anymore (e.g., removed by an earlier inlining
                // in this pass via removeInlinedFunction); counts as a change
                ret = true;
                continue;
            }
            if (!isInliningCandidate(fstmt, fcalls.size())) {
                continue;
            }
            if (LOG.isDebugEnabled()) {
                LOG.debug("IPA: Inline function '" + fkey + "'");
            }

            ArrayList<Hop> body = fstmt.getBody().get(0).getHops();
            List<StatementBlock> fcallsSB = fgraph.getFunctionCallsSB(fkey);
            boolean removedAll = true;
            for (int i = 0; i < fcalls.size(); i++) {
                FunctionOp op = fcalls.get(i);
                if (LOG.isDebugEnabled()) {
                    LOG.debug("-- inline '" + fkey + "' at line " + op.getBeginLine());
                }
                if (!isInlinableCall(fstmt, op)) {
                    removedAll = false;
                    continue;
                }
                inlineFunctionCall(fstmt, body, op, fcallsSB.get(i));
            }

            // update the call graph (avoids repeated inlining) only if no call remains
            if (removedAll) {
                removeInlinedFunction(fgraph, fkey);
                ret = true;
            }
        }
        return ret;
    }

    /**
     * Checks whether a function body is simple enough to be inlined.
     *
     * @param fstmt the function statement
     * @param numCalls number of calls of the function recorded in the call graph
     * @return true if the body is a single last-level statement block without
     *         function calls, and the function has one call site or a small body
     */
    private static boolean isInliningCandidate(FunctionStatement fstmt, int numCalls) {
        if (fstmt.getBody().size() != 1) {
            return false;
        }
        StatementBlock sb = fstmt.getBody().get(0);
        if (!HopRewriteUtils.isLastLevelStatementBlock(sb) || containsFunctionOp(sb.getHops())) {
            return false;
        }
        // with a single call site nothing is replicated, so size only matters otherwise
        return numCalls == 1 || countOperators(sb.getHops()) <= INLINING_MAX_NUM_OPS;
    }

    /**
     * Checks whether a specific call can be replaced. The call must pass as many
     * inputs as the function declares, must not request more outputs than it
     * declares, and must not be a pseudo function call.
     *
     * @param fstmt the called function
     * @param op the function call
     * @return true if {@code op} may be inlined
     */
    private static boolean isInlinableCall(FunctionStatement fstmt, FunctionOp op) {
        return op.getInput().size() == fstmt.getInputParams().size()
            && op.getOutputVariableNames().length <= fstmt.getOutputParams().size()
            && !op.isPseudoFunctionCall();
    }

    /**
     * Replaces one function call with a private copy of the function body.
     *
     * @param fstmt the called function
     * @param body root hops of the function's single statement block (not modified)
     * @param op the function call to replace
     * @param callSB the statement block whose root hops contain {@code op}
     * @throws HopsException if an argument name of the call is not a parameter of the function
     */
    private static void inlineFunctionCall(FunctionStatement fstmt, ArrayList<Hop> body,
            FunctionOp op, StatementBlock callSB) {
        // step 1: deep copy, so each call site gets its own operators
        ArrayList<Hop> hops = Recompiler.deepCopyHopsDag(body);

        // step 2: rewire transient reads of formal parameters to the actual argument hops
        Map<String, Hop> inMap = new HashMap<>();
        for (int j = 0; j < op.getInput().size(); j++) {
            String argName = op.getInputVariableNames()[j];
            if (fstmt.getInputParam(argName) == null) {
                throw new HopsException("Non-existing named function argument: '" + argName
                    + "' in function call '" + op.getFunctionKey() + "' (line " + op.getBeginLine() + ").");
            }
            inMap.put(argName, op.getInput().get(j));
        }
        replaceTransientReads(hops, inMap);

        // step 3: rename transient writes of formal outputs to the caller's output
        // variables; writes of outputs the caller does not bind are dropped
        Map<String, String> outMap = new HashMap<>();
        String[] opOutputs = op.getOutputVariableNames();
        for (int j = 0; j < opOutputs.length; j++) {
            outMap.put(fstmt.getOutputParams().get(j).getName(), opOutputs[j]);
        }
        Iterator<Hop> it = hops.iterator();
        while (it.hasNext()) {
            Hop out = it.next();
            if (HopRewriteUtils.isData(out, Types.OpOpData.TRANSIENTWRITE)) {
                out.setName(outMap.get(out.getName()));
                if (out.getName() == null) {
                    it.remove();
                }
            }
        }

        // step 4: splice the copied DAG into the caller in place of the call
        callSB.getHops().remove(op);
        callSB.getHops().addAll(hops);
    }

    /**
     * Removes a fully inlined function's calls from the call graph. Calls of its
     * callees are also removed if {@code isReachableFunction(callee, true)} no longer
     * reports them as reachable.
     *
     * @param fgraph the function call graph to update
     * @param fkey key of the inlined function
     */
    private static void removeInlinedFunction(FunctionCallGraph fgraph, String fkey) {
        // callees are looked up before fkey's calls are removed
        Set<String> fkeysTrans = fgraph.getCalledFunctions(fkey);
        fgraph.removeFunctionCalls(fkey);
        for (String fkeyTrans : fkeysTrans) {
            if (!fgraph.isReachableFunction(fkeyTrans, true)) {
                fgraph.removeFunctionCalls(fkeyTrans);
            }
        }
    }

    /**
     * Checks whether the DAG rooted at {@code hops} contains a {@link FunctionOp}.
     * Visit flags are reset before and after, because the traversal relies on them.
     *
     * @param hops root hops (may be null or empty)
     * @return true if a function call is contained
     */
    private static boolean containsFunctionOp(List<Hop> hops) {
        if (hops == null || hops.isEmpty()) {
            return false;
        }
        Hop.resetVisitStatus(hops);
        boolean ret = HopRewriteUtils.containsOp(hops, FunctionOp.class);
        Hop.resetVisitStatus(hops);
        return ret;
    }

    /**
     * Counts the distinct operators in the DAG rooted at {@code hops}, excluding
     * {@link DataOp}s and {@link LiteralOp}s. Shared sub-DAGs are counted once.
     *
     * @param hops root hops (may be null or empty)
     * @return number of counted operators
     */
    private static int countOperators(List<Hop> hops) {
        if (hops == null || hops.isEmpty()) {
            return 0;
        }
        Hop.resetVisitStatus(hops);
        int count = 0;
        for (Hop hop : hops) {
            count += rCountOperators(hop);
        }
        Hop.resetVisitStatus(hops);
        return count;
    }

    /** Recursive helper of {@link #countOperators(List)}; skips already visited hops. */
    private static int rCountOperators(Hop current) {
        if (current.isVisited()) {
            return 0;
        }
        // data ops (reads/writes) and literals do not count as operators
        int count = (current instanceof DataOp || current instanceof LiteralOp) ? 0 : 1;
        for (Hop c : current.getInput()) {
            count += rCountOperators(c);
        }
        current.setVisited();
        return count;
    }

    /**
     * Replaces every transient read below the given roots by the hop that
     * {@code inMap} associates with the read's variable name.
     *
     * @param hops root hops of the (copied) function body
     * @param inMap variable name to replacement hop (the caller's argument)
     */
    private static void replaceTransientReads(List<Hop> hops, Map<String, Hop> inMap) {
        Hop.resetVisitStatus(hops);
        for (Hop hop : hops) {
            rReplaceTransientReads(hop, inMap);
        }
        Hop.resetVisitStatus(hops);
    }

    /**
     * Recursive helper of {@link #replaceTransientReads(List, Map)}. Each child is
     * traversed before it is possibly replaced, so the inserted replacement hops
     * (which belong to the caller's DAG) are not traversed themselves.
     */
    private static void rReplaceTransientReads(Hop current, Map<String, Hop> inMap) {
        if (current.isVisited()) {
            return;
        }
        for (int i = 0; i < current.getInput().size(); i++) {
            Hop c = current.getInput().get(i);
            rReplaceTransientReads(c, inMap);
            if (HopRewriteUtils.isData(c, Types.OpOpData.TRANSIENTREAD)) {
                // NOTE(review): assumes every transient read refers to a formal input
                // parameter present in inMap; otherwise null is passed as the new child
                HopRewriteUtils.replaceChildReference(current, c, inMap.get(c.getName()));
            }
        }
        current.setVisited();
    }
}

