/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.ipa;

import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.HopsException;
import org.apache.sysds.hops.ipa.FunctionCallGraph;
import org.apache.sysds.hops.ipa.FunctionCallSizeInfo;
import org.apache.sysds.hops.ipa.IPAPass;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.FunctionStatement;
import org.apache.sysds.parser.FunctionStatementBlock;
import org.apache.sysds.parser.IfStatement;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.LanguageException;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;
import org.apache.sysds.runtime.lineage.LineageCacheConfig;

/**
 * IPA pass that flags non-deterministic functions and statement blocks.
 *
 * <p>A leaf statement block is flagged if one of its HOP DAGs contains an operator for which
 * {@link HopRewriteUtils#isDataGenOpWithNonDeterminism(Hop)} holds, or if it calls a function
 * that is flagged. A function is flagged if its body contains such an operator, or if it
 * (transitively) calls a flagged function. The pass only does work when multi-level lineage
 * reuse or lineage estimation is enabled.
 */
public class IPAPassFlagNonDeterminism extends IPAPass {

    /**
     * The pass is skipped for programs containing second-order function calls, presumably
     * because their callees cannot be resolved statically in the call graph (TODO confirm).
     */
    @Override
    public boolean isApplicable(FunctionCallGraph fgraph) {
        return !fgraph.containsSecondOrderCall();
    }

    /**
     * Flags non-deterministic functions and statement blocks of the given program.
     *
     * @param prog the program whose function and statement blocks are flagged in place
     * @param fgraph the function call graph of {@code prog}
     * @param fcallSizes unused by this pass
     * @return always {@code false}
     * @throws HopsException wrapping any {@link LanguageException} raised while accessing the program
     */
    @Override
    public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes) {
        if (!LineageCacheConfig.isMultiLevelReuse() && !DMLScript.LINEAGE_ESTIMATE) {
            return false;
        }
        try {
            // Step 1: find functions whose own body contains non-deterministic operations.
            // This scan also sets the flag on every leaf statement block it inspects.
            Set<String> ndfncs = new HashSet<>();
            for (String fkey : fgraph.getReachableFunctions()) {
                FunctionStatementBlock fsblock = prog.getFunctionStatementBlock(fkey);
                FunctionStatement fnstmt = (FunctionStatement) fsblock.getStatement(0);
                String fname = DMLProgram.splitFunctionKey(fkey)[1];
                if (rIsNonDeterministicFnc(fname, fnstmt.getBody())) {
                    ndfncs.add(fkey);
                }
            }

            // Step 2: callers of non-deterministic functions are non-deterministic as well.
            // A single depth-first pass can miss callers inside mutually recursive cycles
            // (a back edge may be evaluated before its target is known to be non-deterministic),
            // so repeat until no new function gets added. The set only grows, so this terminates.
            int numND;
            do {
                numND = ndfncs.size();
                propagate2Callers(fgraph, ndfncs, new HashSet<>(), null);
            } while (ndfncs.size() > numND);

            // Step 3: flag the non-deterministic function blocks themselves.
            for (String fkey : ndfncs) {
                prog.getFunctionStatementBlock(fkey).setNondeterministic(true);
            }

            // Step 4: flag statement blocks (main program and function bodies) that call a
            // non-deterministic function.
            rMarkNondeterministicSBs(prog.getStatementBlocks(), ndfncs);
            for (String fkey : fgraph.getReachableFunctions()) {
                FunctionStatementBlock fsblock = prog.getFunctionStatementBlock(fkey);
                FunctionStatement fnstmt = (FunctionStatement) fsblock.getStatement(0);
                rMarkNondeterministicSBs(fnstmt.getBody(), ndfncs);
            }
        } catch (LanguageException ex) {
            throw new HopsException(ex);
        }
        return false;
    }

    /**
     * Recursively checks whether the given statement blocks contain a non-deterministic
     * data-generation operator. Every leaf (non-control-flow) block that has HOPs is flagged
     * with its own result.
     *
     * @param fname name of the enclosing function (currently not used by the analysis)
     * @param sbs statement blocks to scan
     * @return {@code true} if any of the blocks, including nested ones, is non-deterministic
     */
    private boolean rIsNonDeterministicFnc(String fname, List<StatementBlock> sbs) {
        boolean isND = false;
        // All blocks are scanned, without stopping at the first hit, so that every leaf block
        // receives its own flag.
        for (StatementBlock sb : sbs) {
            boolean sbND = false;
            if (sb instanceof ForStatementBlock) {
                ForStatement fstmt = (ForStatement) sb.getStatement(0);
                sbND = rIsNonDeterministicFnc(fname, fstmt.getBody());
            } else if (sb instanceof WhileStatementBlock) {
                WhileStatement wstmt = (WhileStatement) sb.getStatement(0);
                sbND = rIsNonDeterministicFnc(fname, wstmt.getBody());
            } else if (sb instanceof IfStatementBlock) {
                IfStatement ifstmt = (IfStatement) sb.getStatement(0);
                sbND = rIsNonDeterministicFnc(fname, ifstmt.getIfBody());
                if (ifstmt.getElseBody() != null) {
                    // OR (not overwrite) the else-branch result; '|=' does not short-circuit,
                    // so the else body is always scanned and its blocks flagged.
                    sbND |= rIsNonDeterministicFnc(fname, ifstmt.getElseBody());
                }
            } else if (sb.getHops() != null) {
                Hop.resetVisitStatus(sb.getHops());
                for (Hop hop : sb.getHops()) {
                    sbND |= rIsNonDeterministicHop(hop);
                }
                Hop.resetVisitStatus(sb.getHops());
                sb.setNondeterministic(sbND);
            }
            isND |= sbND;
        }
        return isND;
    }

    /**
     * Recursively flags every leaf statement block whose HOP DAGs call one of the given
     * non-deterministic functions. Blocks are only ever set to non-deterministic, never reset.
     *
     * @param sbs statement blocks to scan
     * @param ndfncs keys of the non-deterministic functions
     */
    private void rMarkNondeterministicSBs(List<StatementBlock> sbs, Set<String> ndfncs) {
        for (StatementBlock sb : sbs) {
            if (sb instanceof ForStatementBlock) {
                ForStatement fstmt = (ForStatement) sb.getStatement(0);
                rMarkNondeterministicSBs(fstmt.getBody(), ndfncs);
            } else if (sb instanceof WhileStatementBlock) {
                WhileStatement wstmt = (WhileStatement) sb.getStatement(0);
                rMarkNondeterministicSBs(wstmt.getBody(), ndfncs);
            } else if (sb instanceof IfStatementBlock) {
                IfStatement ifstmt = (IfStatement) sb.getStatement(0);
                rMarkNondeterministicSBs(ifstmt.getIfBody(), ndfncs);
                if (ifstmt.getElseBody() != null) {
                    rMarkNondeterministicSBs(ifstmt.getElseBody(), ndfncs);
                }
            } else if (sb.getHops() != null) {
                boolean callsND = false;
                Hop.resetVisitStatus(sb.getHops());
                for (Hop hop : sb.getHops()) {
                    callsND |= rMarkNondeterministicHop(hop, ndfncs);
                }
                Hop.resetVisitStatus(sb.getHops());
                if (callsND) {
                    sb.setNondeterministic(true);
                }
            }
        }
    }

    /**
     * Checks whether the DAG rooted at {@code hop} contains a {@link FunctionOp} whose name is in
     * {@code ndfncs}. Already-visited HOPs contribute {@code false}, and every HOP reached is
     * marked visited.
     *
     * @param hop root of the (sub-)DAG to scan
     * @param ndfncs keys of the non-deterministic functions, compared against {@code hop.getName()}
     * @return {@code true} if a call to a non-deterministic function was found
     */
    private boolean rMarkNondeterministicHop(Hop hop, Set<String> ndfncs) {
        if (hop.isVisited()) {
            return false;
        }
        boolean callsND = hop instanceof FunctionOp && ndfncs.contains(hop.getName());
        if (!callsND) {
            for (Hop hi : hop.getInput()) {
                callsND |= rMarkNondeterministicHop(hi, ndfncs);
            }
        }
        hop.setVisited();
        return callsND;
    }

    /**
     * Checks whether the DAG rooted at {@code hop} contains an operator for which
     * {@link HopRewriteUtils#isDataGenOpWithNonDeterminism(Hop)} holds. Already-visited HOPs
     * contribute {@code false}, and every HOP reached is marked visited.
     *
     * @param hop root of the (sub-)DAG to scan
     * @return {@code true} if a non-deterministic data-generation operator was found
     */
    private boolean rIsNonDeterministicHop(Hop hop) {
        if (hop.isVisited()) {
            return false;
        }
        boolean isND = HopRewriteUtils.isDataGenOpWithNonDeterminism(hop);
        if (!isND) {
            for (Hop hi : hop.getInput()) {
                isND |= rIsNonDeterministicHop(hi);
            }
        }
        hop.setVisited();
        return isND;
    }

    /**
     * Depth-first walk over the call graph starting at {@code fkey}, where {@code null} denotes
     * the main program. After a callee has been processed, its caller is added to
     * {@code ndfncs} if the callee is in {@code ndfncs}. Calls back to a recursive function that
     * is already on the current call stack are not re-entered, but still propagate the flag.
     *
     * @param fgraph the function call graph
     * @param ndfncs keys of the non-deterministic functions; extended in place
     * @param fstack keys of the functions on the current DFS path
     * @param fkey key of the current caller, or {@code null} for the main program
     */
    private void propagate2Callers(FunctionCallGraph fgraph, Set<String> ndfncs, Set<String> fstack, String fkey) {
        Set<String> cfkeys = fgraph.getCalledFunctions(fkey);
        if (cfkeys == null) {
            return;
        }
        for (String cfkey : cfkeys) {
            boolean backEdge = fstack.contains(cfkey) && fgraph.isRecursiveFunction(cfkey);
            if (!backEdge) {
                fstack.add(cfkey);
                propagate2Callers(fgraph, ndfncs, fstack, cfkey);
                fstack.remove(cfkey);
            }
            if (fkey != null && ndfncs.contains(cfkey)) {
                ndfncs.add(fkey);
            }
        }
    }
}

