/*
 * Decompiled with CFR 0.152.
 */
package org.apache.iotdb.db.queryengine.plan.relational.planner;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.iotdb.db.exception.sql.SemanticException;
import org.apache.iotdb.db.queryengine.common.SessionInfo;
import org.apache.iotdb.db.queryengine.plan.analyze.TypeProvider;
import org.apache.iotdb.db.queryengine.plan.relational.analyzer.NodeRef;
import org.apache.iotdb.db.queryengine.plan.relational.function.OperatorType;
import org.apache.iotdb.db.queryengine.plan.relational.metadata.OperatorNotFoundException;
import org.apache.iotdb.db.queryengine.plan.relational.planner.PlannerContext;
import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol;
import org.apache.iotdb.db.queryengine.plan.relational.security.AccessControl;
import org.apache.iotdb.db.queryengine.plan.relational.security.AllowAllAccessControl;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ArithmeticBinaryExpression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ArithmeticUnaryExpression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.AstVisitor;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.BetweenPredicate;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.BinaryLiteral;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.BooleanLiteral;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Cast;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.CoalesceExpression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ComparisonExpression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.CurrentDatabase;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.CurrentUser;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DoubleLiteral;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Extract;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.FunctionCall;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.GenericLiteral;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.IfExpression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.InListExpression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.InPredicate;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.IsNotNullPredicate;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.IsNullPredicate;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.LikePredicate;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.LogicalExpression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.LongLiteral;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Node;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.NotExpression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.NullIfExpression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.NullLiteral;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Row;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SearchedCaseExpression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SimpleCaseExpression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.StringLiteral;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SymbolReference;
import org.apache.iotdb.db.queryengine.plan.relational.type.TypeSignatureTranslator;
import org.apache.tsfile.read.common.type.BlobType;
import org.apache.tsfile.read.common.type.BooleanType;
import org.apache.tsfile.read.common.type.DateType;
import org.apache.tsfile.read.common.type.DoubleType;
import org.apache.tsfile.read.common.type.IntType;
import org.apache.tsfile.read.common.type.LongType;
import org.apache.tsfile.read.common.type.RowType;
import org.apache.tsfile.read.common.type.StringType;
import org.apache.tsfile.read.common.type.TimestampType;
import org.apache.tsfile.read.common.type.Type;
import org.apache.tsfile.read.common.type.UnknownType;

/**
 * Infers the result {@link Type} of planner IR expressions.
 *
 * <p>Every expression node visited while analyzing (including all sub-expressions) is recorded in
 * the returned map keyed by {@link NodeRef}, so callers can look up the type of any node inside the
 * analyzed expression trees.
 */
public class IrTypeAnalyzer {
    private final PlannerContext plannerContext;

    public IrTypeAnalyzer(PlannerContext plannerContext) {
        this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
    }

    /**
     * Computes the types of the given expressions and of all their sub-expressions.
     *
     * @param session session of the current query
     * @param inputTypes types of the symbols the expressions may reference
     * @param expressions IR expressions to analyze
     * @return insertion-ordered map from every visited expression node to its type
     */
    public Map<NodeRef<Expression>, Type> getTypes(
            SessionInfo session, TypeProvider inputTypes, Iterable<Expression> expressions) {
        Visitor visitor = new Visitor(plannerContext, session, inputTypes);
        for (Expression expression : expressions) {
            // Each top-level expression starts with no extra argument (lambda-style) symbol types.
            visitor.process(expression, new Context(ImmutableMap.of()));
        }
        return visitor.getTypes();
    }

    /**
     * Convenience overload of {@link #getTypes(SessionInfo, TypeProvider, Iterable)} for a single
     * expression.
     */
    public Map<NodeRef<Expression>, Type> getTypes(
            SessionInfo session, TypeProvider inputTypes, Expression expression) {
        return getTypes(session, inputTypes, ImmutableList.of(expression));
    }

    /**
     * Returns the type of {@code expression} itself (not of its sub-expressions).
     *
     * @return the inferred type, looked up from the map produced by {@code getTypes}
     */
    public Type getType(SessionInfo session, TypeProvider inputTypes, Expression expression) {
        return getTypes(session, inputTypes, expression).get(NodeRef.of(expression));
    }

    /** AST visitor that computes and records the type of each IR expression node it visits. */
    private static class Visitor extends AstVisitor<Type, Context> {
        // NOTE(review): not referenced anywhere in this file; candidate for removal.
        private static final AccessControl ALLOW_ALL_ACCESS_CONTROL = new AllowAllAccessControl();

        private final PlannerContext plannerContext;
        private final SessionInfo session;
        private final TypeProvider symbolTypes;
        private final Map<NodeRef<Expression>, Type> expressionTypes = new LinkedHashMap<>();

        public Visitor(PlannerContext plannerContext, SessionInfo session, TypeProvider symbolTypes) {
            this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
            this.session = Objects.requireNonNull(session, "session is null");
            this.symbolTypes = Objects.requireNonNull(symbolTypes, "symbolTypes is null");
        }

        /** Returns the (live, mutable) map of types recorded so far. */
        public Map<NodeRef<Expression>, Type> getTypes() {
            return expressionTypes;
        }

        /** Records {@code type} for {@code expression} and returns it, so visitors can tail-call it. */
        private Type setExpressionType(Expression expression, Type type) {
            Objects.requireNonNull(expression, "expression cannot be null");
            Objects.requireNonNull(type, "type cannot be null");
            expressionTypes.put(NodeRef.of(expression), type);
            return type;
        }

        @Override
        public Type process(Node node, Context context) {
            // Reuse a previously computed type for this node instead of re-visiting it.
            if (node instanceof Expression) {
                Type cached = expressionTypes.get(NodeRef.of((Expression) node));
                if (cached != null) {
                    return cached;
                }
            }
            return super.process(node, context);
        }

        @Override
        protected Type visitSymbolReference(SymbolReference node, Context context) {
            Symbol symbol = Symbol.from(node);
            // Argument types from the context take precedence over the query-wide symbol types.
            Type type = context.getArgumentTypes().get(symbol);
            if (type == null) {
                type = symbolTypes.getTableModelType(symbol);
            }
            Preconditions.checkArgument(type != null, "No type for: %s", node.getName());
            return setExpressionType(node, type);
        }

        @Override
        protected Type visitNotExpression(NotExpression node, Context context) {
            process(node.getValue(), context);
            return setExpressionType(node, BooleanType.BOOLEAN);
        }

        @Override
        protected Type visitLogicalExpression(LogicalExpression node, Context context) {
            for (Expression term : node.getTerms()) {
                process(term, context);
            }
            return setExpressionType(node, BooleanType.BOOLEAN);
        }

        @Override
        protected Type visitComparisonExpression(ComparisonExpression node, Context context) {
            process(node.getLeft(), context);
            process(node.getRight(), context);
            return setExpressionType(node, BooleanType.BOOLEAN);
        }

        @Override
        protected Type visitIsNullPredicate(IsNullPredicate node, Context context) {
            process(node.getValue(), context);
            return setExpressionType(node, BooleanType.BOOLEAN);
        }

        @Override
        protected Type visitIsNotNullPredicate(IsNotNullPredicate node, Context context) {
            process(node.getValue(), context);
            return setExpressionType(node, BooleanType.BOOLEAN);
        }

        @Override
        protected Type visitNullIfExpression(NullIfExpression node, Context context) {
            // NULLIF(a, b) yields a's type; b is visited only so its own type gets recorded.
            Type firstType = process(node.getFirst(), context);
            process(node.getSecond(), context);
            return setExpressionType(node, firstType);
        }

        @Override
        protected Type visitIfExpression(IfExpression node, Context context) {
            Type conditionType = process(node.getCondition(), context);
            Preconditions.checkArgument(
                    conditionType.equals(BooleanType.BOOLEAN),
                    "Condition must be boolean: %s",
                    conditionType);
            Type trueType = process(node.getTrueValue(), context);
            if (node.getFalseValue().isPresent()) {
                Type falseType = process(node.getFalseValue().get(), context);
                Preconditions.checkArgument(
                        trueType.equals(falseType), "Types must be equal: %s vs %s", trueType, falseType);
            }
            return setExpressionType(node, trueType);
        }

        @Override
        protected Type visitSearchedCaseExpression(SearchedCaseExpression node, Context context) {
            // Each WHEN clause is typed as its result; all results must share a single type.
            Set<Type> resultTypes =
                    node.getWhenClauses().stream()
                            .map(
                                    clause -> {
                                        Type operandType = process(clause.getOperand(), context);
                                        Preconditions.checkArgument(
                                                operandType.equals(BooleanType.BOOLEAN),
                                                "When clause operand must be boolean: %s",
                                                operandType);
                                        return setExpressionType(clause, process(clause.getResult(), context));
                                    })
                            .collect(Collectors.toSet());
            Preconditions.checkArgument(
                    resultTypes.size() == 1, "All result types must be the same: %s", resultTypes);
            Type resultType = resultTypes.iterator().next();
            node.getDefaultValue()
                    .ifPresent(
                            defaultValue -> {
                                Type defaultType = process(defaultValue, context);
                                Preconditions.checkArgument(
                                        defaultType.equals(resultType),
                                        "Default result type must be the same as WHEN result types: %s vs %s",
                                        defaultType,
                                        resultType);
                            });
            return setExpressionType(node, resultType);
        }

        @Override
        protected Type visitSimpleCaseExpression(SimpleCaseExpression node, Context context) {
            Type operandType = process(node.getOperand(), context);
            // Every WHEN operand must match the CASE operand; all results must share a single type.
            Set<Type> resultTypes =
                    node.getWhenClauses().stream()
                            .map(
                                    clause -> {
                                        Type clauseOperandType = process(clause.getOperand(), context);
                                        Preconditions.checkArgument(
                                                clauseOperandType.equals(operandType),
                                                "WHEN clause operand type must match CASE operand type: %s vs %s",
                                                clauseOperandType,
                                                operandType);
                                        return setExpressionType(clause, process(clause.getResult(), context));
                                    })
                            .collect(Collectors.toSet());
            Preconditions.checkArgument(
                    resultTypes.size() == 1, "All result types must be the same: %s", resultTypes);
            Type resultType = resultTypes.iterator().next();
            node.getDefaultValue()
                    .ifPresent(
                            defaultValue -> {
                                Type defaultType = process(defaultValue, context);
                                Preconditions.checkArgument(
                                        defaultType.equals(resultType),
                                        "Default result type must be the same as WHEN result types: %s vs %s",
                                        defaultType,
                                        resultType);
                            });
            return setExpressionType(node, resultType);
        }

        @Override
        protected Type visitCoalesceExpression(CoalesceExpression node, Context context) {
            Set<Type> types =
                    node.getOperands().stream()
                            .map(operand -> process(operand, context))
                            .collect(Collectors.toSet());
            Preconditions.checkArgument(
                    types.size() == 1, "All operands must have the same type: %s", types);
            return setExpressionType(node, types.iterator().next());
        }

        @Override
        protected Type visitArithmeticUnary(ArithmeticUnaryExpression node, Context context) {
            // Unary +/- keeps the operand's type.
            return setExpressionType(node, process(node.getValue(), context));
        }

        @Override
        protected Type visitExtract(Extract node, Context context) {
            process(node.getExpression(), context);
            return setExpressionType(node, LongType.INT64);
        }

        @Override
        protected Type visitArithmeticBinary(ArithmeticBinaryExpression node, Context context) {
            Type leftType = process(node.getLeft(), context);
            Type rightType = process(node.getRight(), context);
            try {
                // The operator enum name is mapped 1:1 onto OperatorType to resolve the result type.
                Type resultType =
                        plannerContext
                                .getMetadata()
                                .getOperatorReturnType(
                                        OperatorType.valueOf(node.getOperator().name()),
                                        ImmutableList.of(leftType, rightType));
                return setExpressionType(node, resultType);
            } catch (OperatorNotFoundException e) {
                throw new SemanticException(e.getMessage());
            }
        }

        @Override
        protected Type visitStringLiteral(StringLiteral node, Context context) {
            return setExpressionType(node, StringType.STRING);
        }

        @Override
        protected Type visitBinaryLiteral(BinaryLiteral node, Context context) {
            return setExpressionType(node, BlobType.BLOB);
        }

        @Override
        protected Type visitLongLiteral(LongLiteral node, Context context) {
            // Integer literals that fit in 32 bits are typed INT32, otherwise INT64.
            long value = node.getParsedValue();
            if (value >= Integer.MIN_VALUE && value <= Integer.MAX_VALUE) {
                return setExpressionType(node, IntType.INT32);
            }
            return setExpressionType(node, LongType.INT64);
        }

        @Override
        protected Type visitDoubleLiteral(DoubleLiteral node, Context context) {
            return setExpressionType(node, DoubleType.DOUBLE);
        }

        @Override
        protected Type visitBooleanLiteral(BooleanLiteral node, Context context) {
            return setExpressionType(node, BooleanType.BOOLEAN);
        }

        @Override
        protected Type visitGenericLiteral(GenericLiteral node, Context context) {
            // Declared as Type (not DateType): the branches assign DATE, TIMESTAMP and INT64.
            Type type;
            if (DateType.DATE.getTypeEnum().name().equals(node.getType())) {
                type = DateType.DATE;
            } else if (TimestampType.TIMESTAMP.getTypeEnum().name().equals(node.getType())) {
                type = TimestampType.TIMESTAMP;
            } else if (LongType.INT64.getTypeEnum().name().equals(node.getType())) {
                type = LongType.INT64;
            } else {
                throw new SemanticException("Unsupported type in GenericLiteral: " + node.getType());
            }
            return setExpressionType(node, type);
        }

        @Override
        protected Type visitNullLiteral(NullLiteral node, Context context) {
            return setExpressionType(node, UnknownType.UNKNOWN);
        }

        @Override
        protected Type visitFunctionCall(FunctionCall node, Context context) {
            List<Type> argumentTypes = new ArrayList<>(node.getArguments().size());
            for (Expression argument : node.getArguments()) {
                argumentTypes.add(process(argument, context));
            }
            // Only the unqualified function name (suffix) is used for return-type resolution.
            return setExpressionType(
                    node,
                    plannerContext
                            .getMetadata()
                            .getFunctionReturnType(node.getName().getSuffix(), argumentTypes));
        }

        @Override
        protected Type visitBetweenPredicate(BetweenPredicate node, Context context) {
            process(node.getValue(), context);
            process(node.getMin(), context);
            process(node.getMax(), context);
            return setExpressionType(node, BooleanType.BOOLEAN);
        }

        @Override
        public Type visitCast(Cast node, Context context) {
            process(node.getExpression(), context);
            return setExpressionType(
                    node,
                    plannerContext
                            .getTypeManager()
                            .getType(TypeSignatureTranslator.toTypeSignature(node.getType())));
        }

        @Override
        protected Type visitInPredicate(InPredicate node, Context context) {
            Expression value = node.getValue();
            // NOTE(review): assumes IR IN predicates always carry an explicit value list.
            InListExpression valueList = (InListExpression) node.getValueList();
            Type type = process(value, context);
            for (Expression item : valueList.getValues()) {
                process(item, context);
            }
            // The value list itself is recorded with the type of the tested value.
            setExpressionType(valueList, type);
            return setExpressionType(node, BooleanType.BOOLEAN);
        }

        @Override
        protected Type visitRow(Row node, Context context) {
            List<Type> types =
                    node.getItems().stream()
                            .map(child -> process(child, context))
                            .collect(ImmutableList.toImmutableList());
            return setExpressionType(node, RowType.anonymous(types));
        }

        @Override
        protected Type visitLikePredicate(LikePredicate node, Context context) {
            process(node.getValue(), context);
            process(node.getPattern(), context);
            node.getEscape().ifPresent(escape -> process(escape, context));
            return setExpressionType(node, BooleanType.BOOLEAN);
        }

        @Override
        protected Type visitCurrentDatabase(CurrentDatabase node, Context context) {
            return setExpressionType(node, StringType.STRING);
        }

        @Override
        protected Type visitCurrentUser(CurrentUser node, Context context) {
            return setExpressionType(node, StringType.STRING);
        }

        @Override
        protected Type visitExpression(Expression node, Context context) {
            throw new UnsupportedOperationException("Not a valid IR expression: " + node.getClass().getName());
        }

        @Override
        protected Type visitNode(Node node, Context context) {
            throw new UnsupportedOperationException("Not a valid IR expression: " + node.getClass().getName());
        }
    }

    /** Per-visit state: types of symbols bound in the current scope (checked before TypeProvider). */
    private static class Context {
        private final Map<Symbol, Type> argumentTypes;

        public Context(Map<Symbol, Type> argumentTypes) {
            this.argumentTypes = Objects.requireNonNull(argumentTypes, "argumentTypes is null");
        }

        public Map<Symbol, Type> getArgumentTypes() {
            return argumentTypes;
        }
    }
}

