/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.analysis.resolver

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.{
  AnsiTypeCoercion,
  CollationTypeCoercion,
  TypeCoercion
}
import org.apache.spark.sql.catalyst.expressions.{Expression, OuterReference, SubExprUtils}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, ListAgg}
import org.apache.spark.sql.catalyst.util.toPrettySQL
import org.apache.spark.sql.errors.QueryCompilationErrors

/**
 * Resolver for [[AggregateExpression]]s that can come from either [[FunctionResolver]] or
 * [[ExpressionResolver]]. It handles the resolution and validation of [[AggregateExpression]].
 *
 * @param operatorResolver operator resolver; used here to obtain the subquery registry.
 * @param expressionResolver expression resolver used to resolve children recursively and to
 *                           access the expression resolution context stack, the expression tree
 *                           traversals and the expression ID assigner.
 */
class AggregateExpressionResolver(
    operatorResolver: Resolver,
    expressionResolver: ExpressionResolver)
    extends TreeNodeResolver[AggregateExpression, Expression]
    with ResolvesExpressionChildren
    with CoercesExpressionTypes {

  private val traversals = expressionResolver.getExpressionTreeTraversals

  protected override val ansiTransformations: CoercesExpressionTypes.Transformations =
    AggregateExpressionResolver.ANSI_TYPE_COERCION_TRANSFORMATIONS
  protected override val nonAnsiTransformations: CoercesExpressionTypes.Transformations =
    AggregateExpressionResolver.TYPE_COERCION_TRANSFORMATIONS

  private val expressionResolutionContextStack =
    expressionResolver.getExpressionResolutionContextStack
  private val subqueryRegistry = operatorResolver.getSubqueryRegistry
  private val autoGeneratedAliasProvider = new AutoGeneratedAliasProvider(
    expressionResolver.getExpressionIdAssigner
  )

  /**
   * Resolves the given [[AggregateExpression]] originating from [[ExpressionResolver]] by
   * resolving its children recursively and validating the resolved expression.
   */
  override def resolve(aggregateExpression: AggregateExpression): Expression = {
    val aggregateExpressionWithChildrenResolved =
      withResolvedChildren(aggregateExpression, expressionResolver.resolve _)
        .asInstanceOf[AggregateExpression]
    handleAggregateExpressionWithChildrenResolved(aggregateExpressionWithChildrenResolved)
  }

  /**
   * Resolves the given [[AggregateExpression]] originating from [[FunctionResolver]] by applying
   * type coercion to its children and validating the resolved expression. In this case, it is not
   * necessary to resolve the children recursively, as they were already resolved in
   * [[FunctionResolver]].
   */
  def resolveWithoutRecursingIntoChildren(aggregateExpression: AggregateExpression): Expression = {
    val aggregateExpressionWithTypeCoercedChildren = aggregateExpression
      .mapChildren(
        expression => coerceExpressionTypes(expression, traversals.current)
      )
      .asInstanceOf[AggregateExpression]
    handleAggregateExpressionWithChildrenResolved(aggregateExpressionWithTypeCoercedChildren)
  }

  /**
   * Handles resolution and validation of the [[AggregateExpression]] after its children have been
   * resolved:
   *  - Validation (see `validateResolvedAggregateExpression`):
   *   1. [[ListAgg]] is not allowed in DISTINCT aggregates if it contains [[SortOrder]] different
   *      from its child;
   *   2. Nested aggregate functions are not allowed;
   *   3. Nondeterministic expressions in the subtree of a related aggregate function are not
   *      allowed;
   *  - Resolution:
   *   1. Set `hasAggregateExpressions` and reset `hasAttributeOutsideOfAggregateExpressions` on
   *      the current (top) [[ExpressionResolver.expressionResolutionContextStack]] entry;
   *   2. Handle [[OuterReference]]s in the [[AggregateExpression]], if there are any (see
   *      `handleOuterAggregateExpression`, which also rejects a mix of outer and local
   *      references).
   *
   * @return either the aggregate expression itself, or an [[OuterReference]] replacing it when
   *         it has outer references and outer aggregates can be extracted.
   */
  private def handleAggregateExpressionWithChildrenResolved(
      aggregateExpressionWithChildrenResolved: AggregateExpression): Expression = {
    val expressionResolutionContext = expressionResolutionContextStack.peek()

    // Must run before `hasAggregateExpressions` is set below: the nested-aggregate check reads it.
    validateResolvedAggregateExpression(aggregateExpressionWithChildrenResolved)

    expressionResolutionContext.hasAggregateExpressions = true

    // `hasAttributeOutsideOfAggregateExpressions` tells whether the subtree contains an attribute
    // that is NOT under an `AggregateExpression`, so that `MISSING_GROUP_BY` can be thrown
    // appropriately. Attributes under this aggregate must not count, hence the reset:
    //
    //   - {{{ SELECT COUNT(col1) FROM VALUES (1); }}}
    //     `col1` is under `COUNT`, so the flag must be `false` for this expression.
    //
    //   - {{{ SELECT COUNT(*), col1 + 1 FROM VALUES (1); }}}
    //     `col1` in `col1 + 1` is outside of any aggregate, so the flag ends up `true` for that
    //     expression.
    expressionResolutionContext.hasAttributeOutsideOfAggregateExpressions = false

    if (expressionResolutionContext.hasOuterReferences) {
      handleOuterAggregateExpression(aggregateExpressionWithChildrenResolved)
    } else {
      aggregateExpressionWithChildrenResolved
    }
  }

  /**
   * Validates an [[AggregateExpression]] whose children are already resolved. Throws on:
   *  - a DISTINCT [[ListAgg]] whose order expressions differ from its child;
   *  - a nested aggregate function;
   *  - a nondeterministic child of the aggregate function.
   */
  private def validateResolvedAggregateExpression(aggregateExpression: AggregateExpression): Unit =
    aggregateExpression match {
      case agg @ AggregateExpression(listAgg: ListAgg, _, _, _, _)
          if agg.isDistinct && listAgg.needSaveOrderValue =>
        throwFunctionAndOrderExpressionMismatchError(listAgg)
      case _ =>
        // The caller sets this flag for the current expression only after validation, so `true`
        // here presumably comes from an aggregate among the already-resolved children.
        if (expressionResolutionContextStack.peek().hasAggregateExpressions) {
          throwNestedAggregateFunction(aggregateExpression)
        }

        aggregateExpression.aggregateFunction.children
          .find(child => !child.deterministic)
          .foreach { nonDeterministicChild =>
            throwAggregateFunctionWithNondeterministicExpression(
              aggregateExpression,
              nonDeterministicChild
            )
          }
    }

  /**
   * If the [[AggregateExpression]] has outer references in its subtree, we need to handle it in a
   * special way. The whole process is explained in the [[SubqueryScope]] scaladoc, but in short
   * we need to:
   *  - Validate that we don't have local references in this subtree;
   *  - Create a new subtree without [[OuterReference]]s;
   *  - Alias this subtree and put it inside the current [[SubqueryScope]];
   *  - If outer aggregates are allowed (the current scope has an `aggregateExpressionsExtractor`),
   *    replace the [[AggregateExpression]] with an [[OuterReference]]:
   *     - to the [[Alias]] from the outer [[Aggregate.aggregateExpressions]] list, if the subtree
   *       without [[OuterReference]]s is found there;
   *     - otherwise to a newly auto-generated [[Alias]].
   *    This alias will later be injected into the outer [[Aggregate]];
   *  - Store the name that needs to be used for the [[OuterReference]] in
   *    [[OuterReference.SINGLE_PASS_SQL_STRING_OVERRIDE]], computed from the
   *    [[AggregateExpression]] before the [[OuterReference]]s were stripped;
   *  - Return the original [[AggregateExpression]] otherwise. This is done to stay compatible
   *    with the fixed-point Analyzer - a proper exception will be thrown later by
   *    [[ValidateSubqueryExpression]].
   *
   * NOTE(review): a previous version of this doc mentioned special handling of aggregates inside a
   * [[Sort]] via `handleAggregateExpressionOutsideAggregate`, which is not defined in this class -
   * confirm where (if anywhere) that handling lives.
   *
   * @throws AnalysisException if the subtree mixes outer and local references.
   */
  private def handleOuterAggregateExpression(
      aggregateExpression: AggregateExpression): Expression = {
    if (expressionResolutionContextStack.peek().hasLocalReferences) {
      throw QueryCompilationErrors.mixedRefsInAggFunc(
        aggregateExpression.sql,
        aggregateExpression.origin
      )
    }

    subqueryRegistry.currentScope.aggregateExpressionsExtractor match {
      case Some(aggregateExpressionsExtractor) =>
        val outerReference = extractOuterAggregateExpression(
          aggregateExpression = aggregateExpression,
          aggregateExpressionsExtractor = aggregateExpressionsExtractor
        )
        outerReference.setTagValue(
          OuterReference.SINGLE_PASS_SQL_STRING_OVERRIDE,
          toPrettySQL(aggregateExpression)
        )
        outerReference
      case None =>
        aggregateExpression
    }
  }

  /**
   * Strips outer references from `aggregateExpression`. It registers an alias for the result in
   * the current subquery scope and returns an [[OuterReference]] to that alias. The alias is the
   * one already present in the outer aggregate expressions if `aggregateExpressionsExtractor`
   * finds a match, and a freshly generated one otherwise.
   */
  private def extractOuterAggregateExpression(
      aggregateExpression: AggregateExpression,
      aggregateExpressionsExtractor: GroupingAndAggregateExpressionsExtractor): OuterReference = {
    val aggregateExpressionWithStrippedOuterReferences =
      SubExprUtils.stripOuterReference(aggregateExpression)

    // Created unconditionally, before the lookup, even though it is only used when no matching
    // outer alias is found. Kept that way so calls into the alias provider (which wraps the
    // expression ID assigner) happen in the same order as before.
    val outerAggregateExpressionAlias = autoGeneratedAliasProvider.newOuterAlias(
      child = aggregateExpressionWithStrippedOuterReferences
    )

    val (_, referencedAggregateExpressionAlias) =
      aggregateExpressionsExtractor.collectFirstAggregateExpression(
        aggregateExpressionWithStrippedOuterReferences
      )

    referencedAggregateExpressionAlias match {
      case Some(alias) =>
        subqueryRegistry.currentScope.addAliasForOuterAggregateExpression(alias)
        OuterReference(alias.toAttribute)
      case None =>
        subqueryRegistry.currentScope.addAliasForOuterAggregateExpression(
          outerAggregateExpressionAlias
        )
        OuterReference(outerAggregateExpressionAlias.toAttribute)
    }
  }

  /** Throws the error for a DISTINCT [[ListAgg]] whose order expressions differ from its child. */
  private def throwFunctionAndOrderExpressionMismatchError(listAgg: ListAgg): Nothing = {
    throw QueryCompilationErrors.functionAndOrderExpressionMismatchError(
      listAgg.prettyName,
      listAgg.child,
      listAgg.orderExpressions
    )
  }

  /** Throws `NESTED_AGGREGATE_FUNCTION` at the origin of `aggregateExpression`. */
  private def throwNestedAggregateFunction(aggregateExpression: AggregateExpression): Nothing = {
    throw new AnalysisException(
      errorClass = "NESTED_AGGREGATE_FUNCTION",
      messageParameters = Map.empty,
      origin = aggregateExpression.origin
    )
  }

  /**
   * Throws `AGGREGATE_FUNCTION_WITH_NONDETERMINISTIC_EXPRESSION` for `aggregateExpression`,
   * reported at the origin of the offending nondeterministic child.
   */
  private def throwAggregateFunctionWithNondeterministicExpression(
      aggregateExpression: AggregateExpression,
      nonDeterministicChild: Expression): Nothing = {
    throw new AnalysisException(
      errorClass = "AGGREGATE_FUNCTION_WITH_NONDETERMINISTIC_EXPRESSION",
      messageParameters = Map("sqlExpr" -> toSQLExpr(aggregateExpression)),
      origin = nonDeterministicChild.origin
    )
  }
}

object AggregateExpressionResolver {

  /**
   * Non-ANSI coercion rules for aggregate expression children. They back the resolver's
   * `nonAnsiTransformations`. The order of the entries matters and must mirror the order of the
   * corresponding rules in [[TypeCoercion]].
   */
  private val TYPE_COERCION_TRANSFORMATIONS: Seq[Expression => Expression] = Seq(
    (expression: Expression) => CollationTypeCoercion.apply(expression),
    (expression: Expression) => TypeCoercion.InTypeCoercion.apply(expression),
    (expression: Expression) => TypeCoercion.FunctionArgumentTypeCoercion.apply(expression),
    (expression: Expression) => TypeCoercion.IfTypeCoercion.apply(expression),
    (expression: Expression) => TypeCoercion.ImplicitTypeCoercion.apply(expression)
  )

  /**
   * ANSI coercion rules for aggregate expression children. They back the resolver's
   * `ansiTransformations`. The order of the entries matters and must mirror the order of the
   * corresponding rules in [[AnsiTypeCoercion]].
   */
  private val ANSI_TYPE_COERCION_TRANSFORMATIONS: Seq[Expression => Expression] = Seq(
    (expression: Expression) => CollationTypeCoercion.apply(expression),
    (expression: Expression) => AnsiTypeCoercion.InTypeCoercion.apply(expression),
    (expression: Expression) => AnsiTypeCoercion.FunctionArgumentTypeCoercion.apply(expression),
    (expression: Expression) => AnsiTypeCoercion.IfTypeCoercion.apply(expression),
    (expression: Expression) => AnsiTypeCoercion.ImplicitTypeCoercion.apply(expression)
  )
}
