[SPARK-50990][SQL] Refactor UpCast resolution out of the Analyzer
### What changes were proposed in this pull request?

Refactor `UpCast` resolution out of the `Analyzer` into a standalone `UpCastResolution` object, leaving the `ResolveUpCast` rule as a thin wrapper that delegates to it.

### Why are the changes needed?

To reuse this code in the single-pass `Resolver`.

### Does this PR introduce _any_ user-facing change?

No, just a refactoring.

### How was this patch tested?

Existing tests.

### Was this patch authored or co-authored using generative AI tooling?

Copilot.nvim.

Closes #49669 from vladimirg-db/vladimirg-db/refactor-upcast-resolution-out.

Authored-by: Vladimir Golubev <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
vladimirg-db authored and MaxGekk committed Jan 26, 2025
1 parent 762599c commit d0b1b0b

Showing 2 changed files with 73 additions and 33 deletions.
@@ -3727,45 +3727,16 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
  * Replace the [[UpCast]] expression by [[Cast]], and throw exceptions if the cast may truncate.
  */
 object ResolveUpCast extends Rule[LogicalPlan] {
-  private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = {
-    val fromStr = from match {
-      case l: LambdaVariable => "array element"
-      case e => e.sql
-    }
-    throw QueryCompilationErrors.upCastFailureError(fromStr, from, to, walkedTypePath)
-  }
-
   def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
     _.containsPattern(UP_CAST), ruleId) {
     case p if !p.childrenResolved => p
     case p if p.resolved => p

     case p => p.transformExpressionsWithPruning(_.containsPattern(UP_CAST), ruleId) {
-      case u @ UpCast(child, _, _) if !child.resolved => u
-
-      case UpCast(_, target, _) if target != DecimalType && !target.isInstanceOf[DataType] =>
-        throw SparkException.internalError(
-          s"UpCast only supports DecimalType as AbstractDataType yet, but got: $target")
-
-      case UpCast(child, target, walkedTypePath) if target == DecimalType
-          && child.dataType.isInstanceOf[DecimalType] =>
-        assert(walkedTypePath.nonEmpty,
-          "object DecimalType should only be used inside ExpressionEncoder")
-
-        // SPARK-31750: if we want to upcast to the general decimal type, and the `child` is
-        // already decimal type, we can remove the `Upcast` and accept any precision/scale.
-        // This can happen for cases like `spark.read.parquet("/tmp/file").as[BigDecimal]`.
-        child
-
-      case UpCast(child, target: AtomicType, _)
-          if conf.getConf(SQLConf.LEGACY_LOOSE_UPCAST) &&
-            child.dataType == StringType =>
-        Cast(child, target.asNullable)
-
-      case u @ UpCast(child, _, walkedTypePath) if !Cast.canUpCast(child.dataType, u.dataType) =>
-        fail(child, u.dataType, walkedTypePath)
-
-      case u @ UpCast(child, _, _) => Cast(child, u.dataType)
+      case unresolvedUpCast @ UpCast(child, _, _) if !child.resolved =>
+        unresolvedUpCast
+      case unresolvedUpCast: UpCast =>
+        UpCastResolution.resolve(unresolvedUpCast)
     }
   }
 }
@@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, UpCast}
import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{AtomicType, DataType, DecimalType, StringType}

object UpCastResolution extends SQLConfHelper {
  def resolve(unresolvedUpCast: UpCast): Expression = unresolvedUpCast match {
    case UpCast(_, target, _) if target != DecimalType && !target.isInstanceOf[DataType] =>
      throw SparkException.internalError(
        s"UpCast only supports DecimalType as AbstractDataType yet, but got: $target"
      )

    case UpCast(child, target, walkedTypePath)
        if target == DecimalType
        && child.dataType.isInstanceOf[DecimalType] =>
      assert(
        walkedTypePath.nonEmpty,
        "object DecimalType should only be used inside ExpressionEncoder"
      )

      // SPARK-31750: if we want to upcast to the general decimal type, and the `child` is
      // already decimal type, we can remove the `Upcast` and accept any precision/scale.
      // This can happen for cases like `spark.read.parquet("/tmp/file").as[BigDecimal]`.
      child

    case UpCast(child, target: AtomicType, _)
        if conf.getConf(SQLConf.LEGACY_LOOSE_UPCAST) &&
        child.dataType == StringType =>
      Cast(child, target.asNullable)

    case unresolvedUpCast @ UpCast(child, _, walkedTypePath)
        if !Cast.canUpCast(child.dataType, unresolvedUpCast.dataType) =>
      fail(child, unresolvedUpCast.dataType, walkedTypePath)

    case unresolvedUpCast @ UpCast(child, _, _) =>
      Cast(child, unresolvedUpCast.dataType)
  }

  private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = {
    val fromStr = from match {
      case l: LambdaVariable => "array element"
      case e => e.sql
    }

    throw QueryCompilationErrors.upCastFailureError(fromStr, from, to, walkedTypePath)
  }
}
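
For context, a minimal sketch of how a caller outside the `Analyzer` could reuse the extracted logic. The `SinglePassResolverSketch` object and its `resolveExpression` entry point are illustrative assumptions, not part of this PR; only `UpCastResolution.resolve` comes from the change itself.

```scala
// Hypothetical usage sketch -- not part of this PR. It assumes an expression
// resolver that visits a tree bottom-up, so children are already resolved
// by the time a node is handled (the setting a single-pass Resolver targets).
import org.apache.spark.sql.catalyst.analysis.UpCastResolution
import org.apache.spark.sql.catalyst.expressions.{Expression, UpCast}

object SinglePassResolverSketch {
  def resolveExpression(expr: Expression): Expression = expr match {
    // Delegate to the shared object introduced by this PR instead of
    // duplicating the Analyzer rule's case logic.
    case unresolvedUpCast: UpCast if unresolvedUpCast.child.resolved =>
      UpCastResolution.resolve(unresolvedUpCast)
    case other => other
  }
}
```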
