/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hyracks.algebricks.rewriter.rules;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.apache.commons.lang3.mutable.Mutable;
import org.apache.commons.lang3.mutable.MutableObject;
import org.apache.hyracks.algebricks.common.exceptions.AlgebricksException;
import org.apache.hyracks.algebricks.core.algebra.base.ILogicalExpression;
import org.apache.hyracks.algebricks.core.algebra.base.ILogicalOperator;
import org.apache.hyracks.algebricks.core.algebra.base.IOptimizationContext;
import org.apache.hyracks.algebricks.core.algebra.base.LogicalExpressionTag;
import org.apache.hyracks.algebricks.core.algebra.base.LogicalOperatorTag;
import org.apache.hyracks.algebricks.core.algebra.base.LogicalVariable;
import org.apache.hyracks.algebricks.core.algebra.expressions.AbstractFunctionCallExpression;
import org.apache.hyracks.algebricks.core.algebra.expressions.AbstractLogicalExpression;
import org.apache.hyracks.algebricks.core.algebra.expressions.VariableReferenceExpression;
import org.apache.hyracks.algebricks.core.algebra.functions.FunctionIdentifier;
import org.apache.hyracks.algebricks.core.algebra.operators.logical.AbstractBinaryJoinOperator;
import org.apache.hyracks.algebricks.core.algebra.operators.logical.AbstractLogicalOperator;
import org.apache.hyracks.algebricks.core.algebra.operators.logical.AssignOperator;
import org.apache.hyracks.algebricks.core.algebra.operators.logical.visitors.VariableUtilities;
import org.apache.hyracks.algebricks.core.rewriter.base.IAlgebraicRewriteRule;

/**
 * Rewrite rule that moves selected function calls out of an {@link AssignOperator} and evaluates them
 * below the nearest join underneath that assign.
 * <p>
 * Starting at an assign, the rule walks down single-input operators until it reaches an inner or
 * left-outer join (see {@link #findJoinOp}). Every call to a function in {@code toPushFuncIdents} that
 * occurs anywhere in the assign's expressions, and whose used variables are all live in one join input,
 * is moved into a new assign placed directly on top of that input. The original occurrence is replaced
 * by a reference to a fresh variable.
 * <p>
 * NOTE(review): for a LEFTOUTERJOIN the right input is also a candidate. A variable computed there is
 * presumably null/missing-extended for unmatched left tuples rather than computed from the extended
 * values, so {@code toPushFuncIdents} is assumed to contain only functions for which that is
 * acceptable. Confirm with the rule's callers.
 * <p>
 * The rule reuses scratch lists held in instance fields, so a single instance should not be used by
 * several threads at once.
 */
public class PushFunctionsBelowJoin implements IAlgebraicRewriteRule {
    /** Identifiers of the functions this rule is allowed to push below a join. */
    private final Set<FunctionIdentifier> toPushFuncIdents;
    // Scratch state reused across invocations; cleared before each use.
    private final List<Mutable<ILogicalExpression>> funcExprs = new ArrayList<>();
    private final List<LogicalVariable> usedVars = new ArrayList<>();
    private final List<LogicalVariable> liveVars = new ArrayList<>();

    /**
     * @param toPushFuncIdents identifiers of the functions whose calls may be pushed below a join
     */
    public PushFunctionsBelowJoin(Set<FunctionIdentifier> toPushFuncIdents) {
        this.toPushFuncIdents = toPushFuncIdents;
    }

    /** This rule does all its work in {@link #rewritePost}; the pre-visit never changes the plan. */
    @Override
    public boolean rewritePre(Mutable<ILogicalOperator> opRef, IOptimizationContext context)
            throws AlgebricksException {
        return false;
    }

    /**
     * If {@code opRef} holds an assign with a reachable join below it, pushes eligible function calls
     * from the assign into new assigns on top of the join's inputs.
     *
     * @param opRef reference to the operator being visited
     * @param context optimization context used for fresh variables and type environments
     * @return {@code true} if the plan was modified
     * @throws AlgebricksException if recomputing a type environment fails
     */
    @Override
    public boolean rewritePost(Mutable<ILogicalOperator> opRef, IOptimizationContext context)
            throws AlgebricksException {
        AbstractLogicalOperator op = (AbstractLogicalOperator) opRef.getValue();
        if (op.getOperatorTag() != LogicalOperatorTag.ASSIGN) {
            return false;
        }
        AssignOperator assignOp = (AssignOperator) op;
        Mutable<ILogicalOperator> joinOpRef = findJoinOp(assignOp.getInputs().get(0));
        if (joinOpRef == null) {
            return false;
        }
        AbstractBinaryJoinOperator joinOp = (AbstractBinaryJoinOperator) joinOpRef.getValue();
        funcExprs.clear();
        gatherFunctionCalls(assignOp, funcExprs);
        if (funcExprs.isEmpty()) {
            return false;
        }
        // Try both inputs. Calls pushed into input 0 are removed from funcExprs by pushDownFunctions,
        // so they are not reconsidered for input 1. '|=' (not '||') so both inputs are always tried.
        boolean modified = pushDownFunctions(joinOp, 0, funcExprs, context);
        modified |= pushDownFunctions(joinOp, 1, funcExprs, context);
        if (modified) {
            // The join's inputs changed (new assigns below it), so its environment must be rebuilt.
            context.computeAndSetTypeEnvironmentForOperator(joinOp);
        }
        return modified;
    }

    /**
     * Descends from {@code opRef} through single-input operators looking for an inner or left-outer
     * join.
     *
     * @param opRef reference to the operator to start from
     * @return the reference holding the join, or {@code null} if the search hits an operator the rule
     *         must not look through, or runs out of inputs
     */
    private Mutable<ILogicalOperator> findJoinOp(Mutable<ILogicalOperator> opRef) {
        AbstractLogicalOperator op = (AbstractLogicalOperator) opRef.getValue();
        switch (op.getOperatorTag()) {
            case INNERJOIN:
            case LEFTOUTERJOIN:
                return opRef;
            case GROUP:
            case AGGREGATE:
            case DISTINCT:
            case UNNEST_MAP:
                // Variables created below a PROJECT are dropped by it unless it is updated too, which
                // this rule does not do. The assign above it could then no longer see them.
            case PROJECT:
                return null;
            default:
                break;
        }
        // Follow only operators with exactly one input. For a multi-input operator other than a join
        // (e.g. a union), variables from one branch are not assumed to be visible above it, so
        // stop rather than descend into just the first branch.
        if (op.getInputs().size() != 1) {
            return null;
        }
        return findJoinOp(op.getInputs().get(0));
    }

    /** Collects pushable function calls from every expression of {@code assignOp}, in order. */
    private void gatherFunctionCalls(AssignOperator assignOp, List<Mutable<ILogicalExpression>> funcExprs) {
        for (Mutable<ILogicalExpression> exprRef : assignOp.getExpressions()) {
            gatherFunctionCalls(exprRef, funcExprs);
        }
    }

    /**
     * Appends to {@code funcExprs} the references of all calls to functions in {@code toPushFuncIdents}
     * reachable from {@code exprRef} through function-call arguments.
     * <p>
     * The traversal is pre-order, so a call is always listed before any call nested in its arguments.
     * {@link #pushDownFunctions} relies on this ordering. Descent stops at expressions that are not
     * function calls.
     */
    private void gatherFunctionCalls(Mutable<ILogicalExpression> exprRef,
            List<Mutable<ILogicalExpression>> funcExprs) {
        AbstractLogicalExpression expr = (AbstractLogicalExpression) exprRef.getValue();
        if (expr.getExpressionTag() != LogicalExpressionTag.FUNCTION_CALL) {
            return;
        }
        AbstractFunctionCallExpression funcExpr = (AbstractFunctionCallExpression) expr;
        if (toPushFuncIdents.contains(funcExpr.getFunctionIdentifier())) {
            funcExprs.add(exprRef);
        }
        // Keep descending even below a collected call. If the outer call cannot be pushed, an inner
        // one still might be.
        for (Mutable<ILogicalExpression> funcArg : funcExpr.getArguments()) {
            gatherFunctionCalls(funcArg, funcExprs);
        }
    }

    /**
     * Moves every candidate call whose used variables are all live in input {@code inputIndex} of
     * {@code joinOp} into a new assign placed directly on top of that input.
     * <p>
     * Each moved call is replaced at its original position by a reference to a fresh variable, and is
     * removed from {@code funcExprs}. Candidates nested inside a call moved in this pass are also removed.
     * They travel down with their enclosing call.
     *
     * @param joinOp the join whose input receives the new assign
     * @param inputIndex index of the join input to push into
     * @param funcExprs candidate call references in pre-order; updated in place
     * @param context optimization context used for fresh variables and type environments
     * @return {@code true} if a new assign was inserted
     * @throws AlgebricksException if computing the new assign's type environment fails
     */
    private boolean pushDownFunctions(AbstractBinaryJoinOperator joinOp, int inputIndex,
            List<Mutable<ILogicalExpression>> funcExprs, IOptimizationContext context) throws AlgebricksException {
        Mutable<ILogicalOperator> joinInputRef = joinOp.getInputs().get(inputIndex);
        ILogicalOperator joinInputOp = joinInputRef.getValue();
        liveVars.clear();
        VariableUtilities.getLiveVariables(joinInputOp, liveVars);

        List<LogicalVariable> assignVars = new ArrayList<>();
        List<Mutable<ILogicalExpression>> assignExprs = new ArrayList<>();
        Iterator<Mutable<ILogicalExpression>> funcIter = funcExprs.iterator();
        while (funcIter.hasNext()) {
            Mutable<ILogicalExpression> funcExprRef = funcIter.next();
            // This call sits inside a call already extracted in this pass, so it has moved below the join
            // with its parent. Extracting it again would make the new assign reference a variable it defines
            // itself. Considering it for the other input would reference a variable from the wrong branch.
            if (isNestedInAny(funcExprRef, assignExprs)) {
                funcIter.remove();
                continue;
            }
            ILogicalExpression funcExpr = funcExprRef.getValue();
            usedVars.clear();
            funcExpr.getUsedVariables(usedVars);
            if (!liveVars.containsAll(usedVars)) {
                // Needs variables this input does not produce; leave it for the other input.
                continue;
            }
            LogicalVariable replacementVar = context.newVar();
            assignVars.add(replacementVar);
            assignExprs.add(new MutableObject<ILogicalExpression>(funcExpr));
            VariableReferenceExpression replacementVarRef = new VariableReferenceExpression(replacementVar);
            replacementVarRef.setSourceLocation(funcExpr.getSourceLocation());
            funcExprRef.setValue(replacementVarRef);
            funcIter.remove();
        }
        if (assignVars.isEmpty()) {
            return false;
        }
        AssignOperator newAssign = new AssignOperator(assignVars, assignExprs);
        newAssign.getInputs().add(new MutableObject<ILogicalOperator>(joinInputOp));
        newAssign.setExecutionMode(joinOp.getExecutionMode());
        joinInputRef.setValue(newAssign);
        context.computeAndSetTypeEnvironmentForOperator(newAssign);
        return true;
    }

    /** Returns whether {@code exprRef} is an argument reference somewhere below any of {@code roots}. */
    private static boolean isNestedInAny(Mutable<ILogicalExpression> exprRef,
            List<Mutable<ILogicalExpression>> roots) {
        for (Mutable<ILogicalExpression> root : roots) {
            if (isNestedIn(exprRef, root.getValue())) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns whether {@code exprRef} is (by identity) one of the argument references in the
     * function-call tree rooted at {@code expr}.
     */
    private static boolean isNestedIn(Mutable<ILogicalExpression> exprRef, ILogicalExpression expr) {
        if (expr.getExpressionTag() != LogicalExpressionTag.FUNCTION_CALL) {
            return false;
        }
        for (Mutable<ILogicalExpression> argRef : ((AbstractFunctionCallExpression) expr).getArguments()) {
            // Identity, not equals(): we are looking for the exact slot collected by gatherFunctionCalls.
            if (argRef == exprRef || isNestedIn(exprRef, argRef.getValue())) {
                return true;
            }
        }
        return false;
    }
}

