/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Stack;
import org.apache.hadoop.hive.ql.exec.ColumnInfo;
import org.apache.hadoop.hive.ql.exec.JoinOperator;
import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.OperatorFactory;
import org.apache.hadoop.hive.ql.exec.ReduceSinkOperator;
import org.apache.hadoop.hive.ql.exec.RowSchema;
import org.apache.hadoop.hive.ql.exec.SelectOperator;
import org.apache.hadoop.hive.ql.lib.DefaultRuleDispatcher;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.lib.NodeProcessor;
import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
import org.apache.hadoop.hive.ql.lib.Rule;
import org.apache.hadoop.hive.ql.lib.RuleRegExp;
import org.apache.hadoop.hive.ql.optimizer.Transform;
import org.apache.hadoop.hive.ql.parse.ErrorMsg;
import org.apache.hadoop.hive.ql.parse.GenMapRedWalker;
import org.apache.hadoop.hive.ql.parse.OpParseContext;
import org.apache.hadoop.hive.ql.parse.ParseContext;
import org.apache.hadoop.hive.ql.parse.QBJoinTree;
import org.apache.hadoop.hive.ql.parse.RowResolver;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.plan.PlanUtils;
import org.apache.hadoop.hive.ql.plan.exprNodeColumnDesc;
import org.apache.hadoop.hive.ql.plan.exprNodeDesc;
import org.apache.hadoop.hive.ql.plan.joinCond;
import org.apache.hadoop.hive.ql.plan.joinDesc;
import org.apache.hadoop.hive.ql.plan.mapJoinDesc;
import org.apache.hadoop.hive.ql.plan.reduceSinkDesc;
import org.apache.hadoop.hive.ql.plan.selectDesc;
import org.apache.hadoop.hive.ql.plan.tableDesc;

/**
 * Optimizer transform that converts joins flagged as map-side joins into
 * {@link MapJoinOperator}s and then walks each new map join's subtree to
 * classify it (see {@link #transform(ParseContext)}).
 */
public class MapJoinProcessor
implements Transform {
    // Outer-join type codes compared against joinCond.getType(). The checks in
    // checkOuterJoinsAllowed (getLeft() for 1, getRight() for 2, always reject 3)
    // imply left/right/full outer joins.
    // NOTE(review): presumably these mirror joinDesc's LEFT_OUTER_JOIN /
    // RIGHT_OUTER_JOIN / FULL_OUTER_JOIN constants -- confirm and reference them directly.
    private static final int LEFT_OUTER_JOIN = 1;
    private static final int RIGHT_OUTER_JOIN = 2;
    private static final int FULL_OUTER_JOIN = 3;

    /** Parse context of the plan currently being transformed; set by {@link #transform}. */
    private ParseContext pGraphContext = null;

    /**
     * Registers {@code op} in the parse context's operator map with a new
     * {@link OpParseContext} wrapping {@code rr}.
     *
     * @param op operator to register
     * @param rr row resolver describing the operator's output
     * @return {@code op}, for call chaining
     */
    private Operator<? extends Serializable> putOpInsertMap(Operator<? extends Serializable> op, RowResolver rr) {
        OpParseContext ctx = new OpParseContext(rr);
        this.pGraphContext.getOpParseCtx().put(op, ctx);
        return op;
    }

    /**
     * Throws if any join condition cannot be evaluated as a map join. Rejected
     * conditions are any full outer join, and any left/right outer join whose
     * preserved side is not the input at {@code mapJoinPos}.
     *
     * @throws SemanticException with {@code NO_OUTER_MAPJOIN} for an unsupported condition
     */
    private static void checkOuterJoinsAllowed(joinCond[] conds, int mapJoinPos) throws SemanticException {
        for (joinCond condn : conds) {
            int type = condn.getType();
            boolean unsupported = type == FULL_OUTER_JOIN
                || (type == LEFT_OUTER_JOIN && condn.getLeft() != mapJoinPos)
                || (type == RIGHT_OUTER_JOIN && condn.getRight() != mapJoinPos);
            if (unsupported) {
                throw new SemanticException(ErrorMsg.NO_OUTER_MAPJOIN.getMsg());
            }
        }
    }

    /**
     * Detaches {@code reduceSink} from its single parent. It appends the reduce
     * sink to {@code oldReduceSinks} and that parent to {@code newParents}.
     */
    private static void detachReduceSink(Operator<? extends Serializable> reduceSink,
            List<Operator<? extends Serializable>> oldReduceSinks,
            List<Operator<? extends Serializable>> newParents) {
        assert (reduceSink.getParentOperators().size() == 1);
        // Read the parent before removeChild, which unlinks both directions.
        Operator<? extends Serializable> grandParent = reduceSink.getParentOperators().get(0);
        grandParent.removeChild(reduceSink);
        oldReduceSinks.add(reduceSink);
        newParents.add(grandParent);
    }

    /**
     * Replaces the join operator {@code op} with an equivalent {@link MapJoinOperator}.
     *
     * <p>Each operator feeding the join is assumed to be a reduce sink with
     * exactly one parent. Those reduce sinks are bypassed: they are detached,
     * and their parents become the map join's inputs. Key expressions come from
     * the reduce sinks' key columns. Value expressions are the input columns that
     * also appear in the old join's output. The join's children are re-parented
     * onto the map join, and a pass-through select is then inserted after it.
     *
     * @param pctx       parse context holding per-operator row resolvers
     * @param op         join to convert; its parent and child lists are cleared on return
     * @param joinTree   join tree describing {@code op}'s inputs
     * @param mapJoinPos position of the single input not listed in the join
     *                   tree's map aliases (as computed by {@link #mapSideJoin})
     * @return the newly created map-join operator
     * @throws SemanticException if an outer join condition is incompatible with {@code mapJoinPos}
     */
    private MapJoinOperator convertMapJoin(ParseContext pctx, JoinOperator op, QBJoinTree joinTree, int mapJoinPos) throws SemanticException {
        joinDesc desc = (joinDesc) op.getConf();
        checkOuterJoinsAllowed(desc.getConds(), mapJoinPos);

        RowResolver oldOutputRS = pctx.getOpParseCtx().get(op).getRR();
        RowResolver outputRS = new RowResolver();
        ArrayList<String> outputColumnNames = new ArrayList<String>();
        HashMap<Byte, List<exprNodeDesc>> keyExprMap = new HashMap<Byte, List<exprNodeDesc>>();
        HashMap<Byte, List<exprNodeDesc>> valueExprMap = new HashMap<Byte, List<exprNodeDesc>>();
        HashMap<String, exprNodeDesc> colExprMap = new HashMap<String, exprNodeDesc>();

        // Bypass the reduce sinks: the left (nested join) source, if any, sits at
        // parent position 0; base sources sit at their index in getBaseSrc().
        List<Operator<? extends Serializable>> parentOps = op.getParentOperators();
        ArrayList<Operator<? extends Serializable>> newParentOps = new ArrayList<Operator<? extends Serializable>>();
        ArrayList<Operator<? extends Serializable>> oldReduceSinkParentOps = new ArrayList<Operator<? extends Serializable>>();
        if (joinTree.getJoinSrc() != null) {
            detachReduceSink(parentOps.get(0), oldReduceSinkParentOps, newParentOps);
        }
        int pos = 0;
        for (String src : joinTree.getBaseSrc()) {
            if (src != null) {
                detachReduceSink(parentOps.get(pos), oldReduceSinkParentOps, newParentOps);
            }
            ++pos;
        }

        // Key expressions for each input, keyed by the old reduce sink's tag.
        for (pos = 0; pos < newParentOps.size(); ++pos) {
            ReduceSinkOperator oldPar = (ReduceSinkOperator) oldReduceSinkParentOps.get(pos);
            reduceSinkDesc rsconf = (reduceSinkDesc) oldPar.getConf();
            keyExprMap.put(Byte.valueOf((byte) rsconf.getTag()), rsconf.getKeyCols());
        }

        // Value expressions for each input, keyed by input position.
        // NOTE(review): keys above are keyed by tag, values here by position;
        // presumably the two coincide -- confirm.
        for (pos = 0; pos < newParentOps.size(); ++pos) {
            RowResolver inputRS = pctx.getOpParseCtx().get(newParentOps.get(pos)).getRR();
            List<exprNodeDesc> values = new ArrayList<exprNodeDesc>();
            for (String tabAlias : inputRS.getTableNames()) {
                HashMap<String, ColumnInfo> rrMap = inputRS.getFieldMap(tabAlias);
                for (String field : rrMap.keySet()) {
                    ColumnInfo valueInfo = inputRS.get(tabAlias, field);
                    ColumnInfo oldValueInfo = oldOutputRS.get(tabAlias, field);
                    // Skip columns the old join did not output, and columns already emitted.
                    if (oldValueInfo == null || outputRS.get(tabAlias, field) != null) {
                        continue;
                    }
                    // Keep the old join's internal column name so downstream references still resolve.
                    String outputCol = oldValueInfo.getInternalName();
                    outputColumnNames.add(outputCol);
                    exprNodeDesc colDesc = new exprNodeColumnDesc(valueInfo.getType(), valueInfo.getInternalName(), valueInfo.getTabAlias(), valueInfo.getIsPartitionCol());
                    values.add(colDesc);
                    outputRS.put(tabAlias, field, new ColumnInfo(outputCol, valueInfo.getType(), valueInfo.getTabAlias(), valueInfo.getIsPartitionCol()));
                    colExprMap.put(outputCol, colDesc);
                }
            }
            valueExprMap.put(Byte.valueOf((byte) pos), values);
        }

        Operator[] newPar = newParentOps.toArray(new Operator[newParentOps.size()]);

        // The key table schema is taken from the input tagged 0.
        List<exprNodeDesc> keyCols = keyExprMap.get(Byte.valueOf((byte) 0));
        tableDesc keyTableDesc = PlanUtils.getMapJoinKeyTableDesc(PlanUtils.getFieldSchemasFromColumnList(keyCols, "mapjoinkey"));
        ArrayList<tableDesc> valueTableDescs = new ArrayList<tableDesc>();
        for (pos = 0; pos < newParentOps.size(); ++pos) {
            List<exprNodeDesc> valueCols = valueExprMap.get(Byte.valueOf((byte) pos));
            valueTableDescs.add(PlanUtils.getMapJoinValueTableDesc(PlanUtils.getFieldSchemasFromColumnList(valueCols, "mapjoinvalue")));
        }

        mapJoinDesc mjDesc = new mapJoinDesc(keyExprMap, keyTableDesc, valueExprMap, valueTableDescs, outputColumnNames, mapJoinPos, desc.getConds());
        MapJoinOperator mapJoinOp = (MapJoinOperator) this.putOpInsertMap(OperatorFactory.getAndMakeChild(mjDesc, new RowSchema(outputRS.getColumnInfos()), newPar), outputRS);
        ((mapJoinDesc) mapJoinOp.getConf()).setReversedExprs(desc.getReversedExprs());
        mapJoinOp.setColumnExprMap(colExprMap);

        // Splice the map join into the join's place and detach the old join.
        List<Operator<? extends Serializable>> childOps = op.getChildOperators();
        for (Operator<? extends Serializable> child : childOps) {
            child.replaceParent(op, mapJoinOp);
        }
        mapJoinOp.setChildOperators(childOps);
        mapJoinOp.setParentOperators(newParentOps);
        op.setChildOperators(null);
        op.setParentOperators(null);
        this.genSelectPlan(pctx, mapJoinOp);
        return mapJoinOp;
    }

    /**
     * Inserts a select operator between {@code input} and its current children.
     * The select projects each of the map join's output columns unchanged, by
     * internal name. It is registered with, and its schema taken from,
     * {@code input}'s own row resolver.
     *
     * @param pctx  parse context holding per-operator row resolvers
     * @param input map join whose children are moved under the new select
     */
    private void genSelectPlan(ParseContext pctx, MapJoinOperator input) throws SemanticException {
        List<Operator<? extends Serializable>> childOps = input.getChildOperators();
        input.setChildOperators(null);
        RowResolver inputRR = pctx.getOpParseCtx().get(input).getRR();
        ArrayList<exprNodeDesc> exprs = new ArrayList<exprNodeDesc>();
        ArrayList<String> outputs = new ArrayList<String>();
        HashMap<String, exprNodeDesc> colExprMap = new HashMap<String, exprNodeDesc>();
        List<String> outputCols = ((mapJoinDesc) input.getConf()).getOutputColumnNames();
        for (String internalName : outputCols) {
            // nm[0] is the table alias, nm[1] the column name.
            String[] nm = inputRR.reverseLookup(internalName);
            ColumnInfo valueInfo = inputRR.get(nm[0], nm[1]);
            exprNodeDesc colDesc = new exprNodeColumnDesc(valueInfo.getType(), valueInfo.getInternalName(), nm[0], valueInfo.getIsPartitionCol());
            exprs.add(colDesc);
            outputs.add(internalName);
            colExprMap.put(internalName, colDesc);
        }
        selectDesc select = new selectDesc(exprs, outputs, false);
        SelectOperator sel = (SelectOperator) this.putOpInsertMap(OperatorFactory.getAndMakeChild(select, new RowSchema(inputRR.getColumnInfos()), input), inputRR);
        sel.setColumnExprMap(colExprMap);
        sel.setChildOperators(childOps);
        for (Operator<? extends Serializable> ch : childOps) {
            ch.replaceParent(input, sel);
        }
    }

    /**
     * Returns the position of the one join input that is not in the join tree's
     * map aliases. Position 0 is used for the nested join source, if present;
     * otherwise the position is the index in {@code getBaseSrc()}.
     *
     * @param op       the join operator (currently unused)
     * @param joinTree join tree to inspect
     * @return that position; -1 if the join is not marked map-side, or if more
     *         than one input is outside the map aliases
     * @throws SemanticException {@code INVALID_MAPJOIN_HINT} if the join is marked
     *         map-side but every input is in the map aliases
     */
    private int mapSideJoin(JoinOperator op, QBJoinTree joinTree) throws SemanticException {
        if (!joinTree.isMapSideJoin()) {
            return -1;
        }
        // A nested join source is never among the map aliases.
        int mapJoinPos = joinTree.getJoinSrc() != null ? 0 : -1;
        int pos = 0;
        for (String src : joinTree.getBaseSrc()) {
            if (src != null && !joinTree.getMapAliases().contains(src)) {
                if (mapJoinPos >= 0) {
                    // More than one non-map input: not convertible.
                    return -1;
                }
                mapJoinPos = pos;
            }
            ++pos;
        }
        if (mapJoinPos == -1) {
            throw new SemanticException(ErrorMsg.INVALID_MAPJOIN_HINT.getMsg(this.pGraphContext.getQB().getParseInfo().getHints()));
        }
        return mapJoinPos;
    }

    /**
     * Converts every eligible join in {@code pactx}'s join context into a map
     * join. Unconverted joins stay in the join context; converted ones are removed.
     *
     * <p>It then walks from each new map join using these rules:
     * <ul>
     *   <li>reaching a reduce sink, another map join or a union rejects it;</li>
     *   <li>reaching a file sink records it in the parse context's
     *       "no reducer" list, unless it was already rejected.</li>
     * </ul>
     *
     * @param pactx parse context to transform in place
     * @return {@code pactx}
     */
    @Override
    public ParseContext transform(ParseContext pactx) throws SemanticException {
        this.pGraphContext = pactx;
        ArrayList<MapJoinOperator> listMapJoinOps = new ArrayList<MapJoinOperator>();
        if (this.pGraphContext.getJoinContext() != null) {
            HashMap<JoinOperator, QBJoinTree> joinMap = new HashMap<JoinOperator, QBJoinTree>();
            Set<Map.Entry<JoinOperator, QBJoinTree>> joinCtx = this.pGraphContext.getJoinContext().entrySet();
            for (Map.Entry<JoinOperator, QBJoinTree> joinEntry : joinCtx) {
                JoinOperator joinOp = joinEntry.getKey();
                QBJoinTree qbJoin = joinEntry.getValue();
                int mapJoinPos = this.mapSideJoin(joinOp, qbJoin);
                if (mapJoinPos >= 0) {
                    listMapJoinOps.add(this.convertMapJoin(pactx, joinOp, qbJoin, mapJoinPos));
                } else {
                    joinMap.put(joinOp, qbJoin);
                }
            }
            this.pGraphContext.setJoinContext(joinMap);
        }

        ArrayList<MapJoinOperator> listMapJoinOpsNoRed = new ArrayList<MapJoinOperator>();
        LinkedHashMap<Rule, NodeProcessor> opRules = new LinkedHashMap<Rule, NodeProcessor>();
        opRules.put(new RuleRegExp("R0", "MAPJOIN%"), MapJoinProcessor.getCurrentMapJoin());
        opRules.put(new RuleRegExp("R1", "MAPJOIN%.*FS%"), MapJoinProcessor.getMapJoinFS());
        opRules.put(new RuleRegExp("R2", "MAPJOIN%.*RS%"), MapJoinProcessor.getMapJoinDefault());
        opRules.put(new RuleRegExp("R3", "MAPJOIN%.*MAPJOIN%"), MapJoinProcessor.getMapJoinDefault());
        opRules.put(new RuleRegExp("R4", "MAPJOIN%.*UNION%"), MapJoinProcessor.getMapJoinDefault());
        DefaultRuleDispatcher disp = new DefaultRuleDispatcher(MapJoinProcessor.getDefault(), opRules, new MapJoinWalkerCtx(listMapJoinOpsNoRed));
        GenMapRedWalker ogw = new GenMapRedWalker(disp);
        ArrayList<Node> topNodes = new ArrayList<Node>(listMapJoinOps);
        ogw.startWalking(topNodes, null);
        this.pGraphContext.setListMapJoinOpsNoReducer(listMapJoinOpsNoRed);
        return this.pGraphContext;
    }

    /** @return a processor that records the current map join as reducer-free */
    public static NodeProcessor getMapJoinFS() {
        return new MapJoinFS();
    }

    /** @return a processor that rejects the current map join */
    public static NodeProcessor getMapJoinDefault() {
        return new MapJoinDefault();
    }

    /** @return a processor that does nothing */
    public static NodeProcessor getDefault() {
        return new Default();
    }

    /** @return a processor that makes the visited map join the current one */
    public static NodeProcessor getCurrentMapJoin() {
        return new CurrentMapJoin();
    }

    /**
     * Walker state shared by the processors. It holds the map join currently
     * being walked, the map joins recorded as reducer-free, and the rejected map joins.
     */
    public static class MapJoinWalkerCtx
    implements NodeProcessorCtx {
        List<MapJoinOperator> listMapJoinsNoRed;
        List<MapJoinOperator> listRejectedMapJoins;
        MapJoinOperator currMapJoinOp;

        /**
         * @param listMapJoinsNoRed list that reducer-free map joins are appended to
         */
        public MapJoinWalkerCtx(List<MapJoinOperator> listMapJoinsNoRed) {
            this.listMapJoinsNoRed = listMapJoinsNoRed;
            this.currMapJoinOp = null;
            this.listRejectedMapJoins = new ArrayList<MapJoinOperator>();
        }

        public List<MapJoinOperator> getListMapJoinsNoRed() {
            return this.listMapJoinsNoRed;
        }

        public void setListMapJoins(List<MapJoinOperator> listMapJoinsNoRed) {
            this.listMapJoinsNoRed = listMapJoinsNoRed;
        }

        public MapJoinOperator getCurrMapJoinOp() {
            return this.currMapJoinOp;
        }

        public void setCurrMapJoinOp(MapJoinOperator currMapJoinOp) {
            this.currMapJoinOp = currMapJoinOp;
        }

        public List<MapJoinOperator> getListRejectedMapJoins() {
            return this.listRejectedMapJoins;
        }

        public void setListRejectedMapJoins(List<MapJoinOperator> listRejectedMapJoins) {
            this.listRejectedMapJoins = listRejectedMapJoins;
        }
    }

    /** Fallback processor for nodes that match no rule; does nothing. */
    public static class Default
    implements NodeProcessor {
        @Override
        public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx, Object ... nodeOutputs) throws SemanticException {
            return null;
        }
    }

    /** Adds the current map join to the rejected list (at most once). */
    public static class MapJoinDefault
    implements NodeProcessor {
        @Override
        public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx, Object ... nodeOutputs) throws SemanticException {
            MapJoinWalkerCtx ctx = (MapJoinWalkerCtx) procCtx;
            MapJoinOperator mapJoin = ctx.getCurrMapJoinOp();
            List<MapJoinOperator> listRejectedMapJoins = ctx.getListRejectedMapJoins();
            if (listRejectedMapJoins == null) {
                listRejectedMapJoins = new ArrayList<MapJoinOperator>();
            }
            // Several matching operators below one map join must not add it repeatedly.
            if (!listRejectedMapJoins.contains(mapJoin)) {
                listRejectedMapJoins.add(mapJoin);
            }
            ctx.setListRejectedMapJoins(listRejectedMapJoins);
            return null;
        }
    }

    /**
     * Adds the current map join to the reducer-free list (at most once), unless
     * it has already been rejected.
     */
    public static class MapJoinFS
    implements NodeProcessor {
        @Override
        public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx, Object ... nodeOutputs) throws SemanticException {
            MapJoinWalkerCtx ctx = (MapJoinWalkerCtx) procCtx;
            MapJoinOperator mapJoin = ctx.getCurrMapJoinOp();
            List<MapJoinOperator> listRejectedMapJoins = ctx.getListRejectedMapJoins();
            if (listRejectedMapJoins != null && listRejectedMapJoins.contains(mapJoin)) {
                return null;
            }
            List<MapJoinOperator> listMapJoinsNoRed = ctx.getListMapJoinsNoRed();
            if (listMapJoinsNoRed == null) {
                listMapJoinsNoRed = new ArrayList<MapJoinOperator>();
            }
            // A map join reaching several file sinks must not be added repeatedly.
            if (!listMapJoinsNoRed.contains(mapJoin)) {
                listMapJoinsNoRed.add(mapJoin);
            }
            ctx.setListMapJoins(listMapJoinsNoRed);
            return null;
        }
    }

    /** Sets the visited node, which must be a map join, as the current map join. */
    public static class CurrentMapJoin
    implements NodeProcessor {
        @Override
        public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx, Object ... nodeOutputs) throws SemanticException {
            MapJoinWalkerCtx ctx = (MapJoinWalkerCtx) procCtx;
            ctx.setCurrMapJoinOp((MapJoinOperator) nd);
            return null;
        }
    }
}

