/*
 * Decompiled with CFR 0.152.
 */
package org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import java.util.EnumSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;
import org.apache.iotdb.db.queryengine.common.QueryId;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode;
import org.apache.iotdb.db.queryengine.plan.relational.metadata.Metadata;
import org.apache.iotdb.db.queryengine.plan.relational.metadata.ResolvedFunction;
import org.apache.iotdb.db.queryengine.plan.relational.planner.Assignments;
import org.apache.iotdb.db.queryengine.plan.relational.planner.SimplePlanRewriter;
import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol;
import org.apache.iotdb.db.queryengine.plan.relational.planner.SymbolAllocator;
import org.apache.iotdb.db.queryengine.plan.relational.planner.ir.IrUtils;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.ApplyNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.CorrelatedJoinNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.JoinNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.ProjectNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations.PlanOptimizer;
import org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations.Util;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.BooleanLiteral;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Cast;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ComparisonExpression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.GenericLiteral;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.NullLiteral;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SearchedCaseExpression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SimpleCaseExpression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.WhenClause;
import org.apache.iotdb.db.queryengine.plan.relational.type.TypeSignatureTranslator;
import org.apache.tsfile.read.common.type.BooleanType;
import org.apache.tsfile.read.common.type.LongType;
import org.apache.tsfile.read.common.type.Type;

public class TransformQuantifiedComparisonApplyToCorrelatedJoin
implements PlanOptimizer {
    private final Metadata metadata;

    public TransformQuantifiedComparisonApplyToCorrelatedJoin(Metadata metadata) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
    }

    @Override
    public PlanNode optimize(PlanNode plan, PlanOptimizer.Context context) {
        return SimplePlanRewriter.rewriteWith(new Rewriter(context.idAllocator(), context.getSymbolAllocator(), this.metadata), plan, null);
    }

    private static class Rewriter
    extends SimplePlanRewriter<PlanNode> {
        private final QueryId idAllocator;
        private final SymbolAllocator symbolAllocator;
        private final Metadata metadata;

        public Rewriter(QueryId idAllocator, SymbolAllocator symbolAllocator, Metadata metadata) {
            this.idAllocator = Objects.requireNonNull(idAllocator, "idAllocator is null");
            this.symbolAllocator = Objects.requireNonNull(symbolAllocator, "symbolAllocator is null");
            this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        }

        @Override
        public PlanNode visitApply(ApplyNode node, SimplePlanRewriter.RewriteContext<PlanNode> context) {
            if (node.getSubqueryAssignments().size() != 1) {
                return context.defaultRewrite(node);
            }
            ApplyNode.SetExpression expression = (ApplyNode.SetExpression)Iterables.getOnlyElement(node.getSubqueryAssignments().values());
            if (expression instanceof ApplyNode.QuantifiedComparison) {
                return this.rewriteQuantifiedApplyNode(node, (ApplyNode.QuantifiedComparison)expression, context);
            }
            return context.defaultRewrite(node);
        }

        private PlanNode rewriteQuantifiedApplyNode(ApplyNode node, ApplyNode.QuantifiedComparison quantifiedComparison, SimplePlanRewriter.RewriteContext<PlanNode> context) {
            PlanNode subqueryPlan = context.rewrite(node.getSubquery());
            Symbol outputColumn = (Symbol)Iterables.getOnlyElement(subqueryPlan.getOutputSymbols());
            Type outputColumnType = this.symbolAllocator.getTypes().getTableModelType(outputColumn);
            Preconditions.checkState((boolean)outputColumnType.isOrderable(), (Object)"Subquery result type must be orderable");
            Symbol minValue = this.symbolAllocator.newSymbol("min", outputColumnType);
            Symbol maxValue = this.symbolAllocator.newSymbol("max", outputColumnType);
            Symbol countAllValue = this.symbolAllocator.newSymbol("count_all", (Type)LongType.getInstance());
            Symbol countNonNullValue = this.symbolAllocator.newSymbol("count_non_null", (Type)LongType.getInstance());
            ImmutableList outputColumnReferences = ImmutableList.of((Object)outputColumn.toSymbolReference());
            subqueryPlan = AggregationNode.singleAggregation(this.idAllocator.genPlanNodeId(), subqueryPlan, (Map<Symbol, AggregationNode.Aggregation>)ImmutableMap.of((Object)minValue, (Object)new AggregationNode.Aggregation(this.getResolvedBuiltInAggregateFunction("min", (List<Type>)ImmutableList.of((Object)outputColumnType)), (List<Expression>)outputColumnReferences, false, Optional.empty(), Optional.empty(), Optional.empty()), (Object)maxValue, (Object)new AggregationNode.Aggregation(this.getResolvedBuiltInAggregateFunction("max", (List<Type>)ImmutableList.of((Object)outputColumnType)), (List<Expression>)outputColumnReferences, false, Optional.empty(), Optional.empty(), Optional.empty()), (Object)countAllValue, (Object)new AggregationNode.Aggregation(this.getResolvedBuiltInAggregateFunction("count_all", (List<Type>)ImmutableList.of((Object)outputColumnType)), (List<Expression>)outputColumnReferences, false, Optional.empty(), Optional.empty(), Optional.empty()), (Object)countNonNullValue, (Object)new AggregationNode.Aggregation(this.getResolvedBuiltInAggregateFunction("count", (List<Type>)ImmutableList.of((Object)outputColumnType)), (List<Expression>)outputColumnReferences, false, Optional.empty(), Optional.empty(), Optional.empty())), AggregationNode.globalAggregation());
            CorrelatedJoinNode join = new CorrelatedJoinNode(node.getPlanNodeId(), context.rewrite(node.getInput()), subqueryPlan, node.getCorrelation(), JoinNode.JoinType.INNER, BooleanLiteral.TRUE_LITERAL, node.getOriginSubquery());
            Expression valueComparedToSubquery = this.rewriteUsingBounds(quantifiedComparison, minValue, maxValue, countAllValue, countNonNullValue);
            Symbol quantifiedComparisonSymbol = (Symbol)Iterables.getOnlyElement(node.getSubqueryAssignments().keySet());
            return this.projectExpressions(join, Assignments.of(quantifiedComparisonSymbol, valueComparedToSubquery));
        }

        private ResolvedFunction getResolvedBuiltInAggregateFunction(String functionName, List<Type> argumentTypes) {
            return Util.getResolvedBuiltInAggregateFunction(this.metadata, functionName, argumentTypes);
        }

        public Expression rewriteUsingBounds(ApplyNode.QuantifiedComparison quantifiedComparison, Symbol minValue, Symbol maxValue, Symbol countAllValue, Symbol countNonNullValue) {
            Function<List, Expression> quantifier;
            BooleanLiteral emptySetResult;
            if (quantifiedComparison.getQuantifier() == ApplyNode.Quantifier.ALL) {
                emptySetResult = BooleanLiteral.TRUE_LITERAL;
                quantifier = IrUtils::combineConjuncts;
            } else {
                emptySetResult = BooleanLiteral.FALSE_LITERAL;
                quantifier = IrUtils::combineDisjuncts;
            }
            Expression comparisonWithExtremeValue = this.getBoundComparisons(quantifiedComparison, minValue, maxValue);
            return new SimpleCaseExpression(countAllValue.toSymbolReference(), (List<WhenClause>)ImmutableList.of((Object)new WhenClause(new GenericLiteral("INT64", "0"), emptySetResult)), quantifier.apply((List)ImmutableList.of((Object)comparisonWithExtremeValue, (Object)new SearchedCaseExpression((List<WhenClause>)ImmutableList.of((Object)new WhenClause(new ComparisonExpression(ComparisonExpression.Operator.NOT_EQUAL, countAllValue.toSymbolReference(), countNonNullValue.toSymbolReference()), new Cast(new NullLiteral(), TypeSignatureTranslator.toSqlType((Type)BooleanType.BOOLEAN)))), emptySetResult))));
        }

        private Expression getBoundComparisons(ApplyNode.QuantifiedComparison quantifiedComparison, Symbol minValue, Symbol maxValue) {
            if (Rewriter.mapOperator(quantifiedComparison) == ComparisonExpression.Operator.EQUAL && quantifiedComparison.getQuantifier() == ApplyNode.Quantifier.ALL) {
                return IrUtils.combineConjuncts(new ComparisonExpression(ComparisonExpression.Operator.EQUAL, minValue.toSymbolReference(), maxValue.toSymbolReference()), new ComparisonExpression(ComparisonExpression.Operator.EQUAL, quantifiedComparison.getValue().toSymbolReference(), maxValue.toSymbolReference()));
            }
            if (EnumSet.of(ComparisonExpression.Operator.LESS_THAN, ComparisonExpression.Operator.LESS_THAN_OR_EQUAL, ComparisonExpression.Operator.GREATER_THAN, ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL).contains((Object)Rewriter.mapOperator(quantifiedComparison))) {
                Symbol boundValue = Rewriter.shouldCompareValueWithLowerBound(quantifiedComparison) ? minValue : maxValue;
                return new ComparisonExpression(Rewriter.mapOperator(quantifiedComparison), quantifiedComparison.getValue().toSymbolReference(), boundValue.toSymbolReference());
            }
            throw new IllegalArgumentException("Unsupported quantified comparison: " + quantifiedComparison);
        }

        private static ComparisonExpression.Operator mapOperator(ApplyNode.QuantifiedComparison quantifiedComparison) {
            switch (quantifiedComparison.getOperator()) {
                case EQUAL: {
                    return ComparisonExpression.Operator.EQUAL;
                }
                case NOT_EQUAL: {
                    return ComparisonExpression.Operator.NOT_EQUAL;
                }
                case LESS_THAN: {
                    return ComparisonExpression.Operator.LESS_THAN;
                }
                case LESS_THAN_OR_EQUAL: {
                    return ComparisonExpression.Operator.LESS_THAN_OR_EQUAL;
                }
                case GREATER_THAN: {
                    return ComparisonExpression.Operator.GREATER_THAN;
                }
                case GREATER_THAN_OR_EQUAL: {
                    return ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL;
                }
            }
            throw new IllegalArgumentException("Unexpected quantifiedComparison: " + (Object)((Object)quantifiedComparison.getOperator()));
        }

        private static boolean shouldCompareValueWithLowerBound(ApplyNode.QuantifiedComparison quantifiedComparison) {
            ComparisonExpression.Operator operator = Rewriter.mapOperator(quantifiedComparison);
            switch (quantifiedComparison.getQuantifier()) {
                case ALL: {
                    switch (operator) {
                        case LESS_THAN: 
                        case LESS_THAN_OR_EQUAL: {
                            return true;
                        }
                        case GREATER_THAN: 
                        case GREATER_THAN_OR_EQUAL: {
                            return false;
                        }
                    }
                    throw new IllegalArgumentException("Unexpected value: " + (Object)((Object)operator));
                }
                case ANY: 
                case SOME: {
                    switch (operator) {
                        case LESS_THAN: 
                        case LESS_THAN_OR_EQUAL: {
                            return false;
                        }
                        case GREATER_THAN: 
                        case GREATER_THAN_OR_EQUAL: {
                            return true;
                        }
                    }
                    throw new IllegalArgumentException("Unexpected value: " + (Object)((Object)operator));
                }
            }
            throw new IllegalArgumentException("Unexpected Quantifier: " + (Object)((Object)quantifiedComparison.getQuantifier()));
        }

        private ProjectNode projectExpressions(PlanNode input, Assignments subqueryAssignments) {
            Assignments assignments = Assignments.builder().putIdentities(input.getOutputSymbols()).putAll(subqueryAssignments).build();
            return new ProjectNode(this.idAllocator.genPlanNodeId(), input, assignments);
        }
    }
}

