Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP][SPARK-50983][SQL] Support Nested Correlated Subqueries for Analyzer #49660

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -5906,6 +5906,11 @@
"Correlated scalar subqueries must be aggregated to return at most one row."
]
},
"NESTED_CORRELATED_SUBQUERIES_NOT_SUPPORTED" : {
"message" : [
"Nested correlated subqueries are not supported."
]
},
"NON_CORRELATED_COLUMNS_IN_GROUP_BY" : {
"message" : [
"A GROUP BY clause in a scalar correlated subquery cannot contain non-correlated columns: <value>."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2263,6 +2263,37 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
* Note: CTEs are handled in CTESubstitution.
*/
object ResolveSubquery extends Rule[LogicalPlan] {

/**
 * Gathers the unresolved outer attributes reported by the subquery expressions that appear
 * in `plan` and in its child operators.
 *
 * For each expression of `plan`: a [[SubqueryExpression]] contributes its own
 * `getUnresolvedOuterAttrs`, an [[InSubquery]] contributes those of its `query`, and any
 * other expression carrying the `PLAN_EXPRESSION` pattern contributes those of every
 * [[SubqueryExpression]] found inside it. Child operators are visited recursively, but only
 * when they contain the `PLAN_EXPRESSION` pattern.
 *
 * @param plan the operator whose expressions and children are scanned
 * @return the collected attributes: this operator's own first, then those of its children
 */
private def getOuterAttrsNeedToBePropagated(plan: LogicalPlan): Seq[Expression] = {
  // Attributes contributed by the expressions of this operator.
  val fromOwnExpressions = plan.expressions.flatMap {
    case subquery: SubqueryExpression => subquery.getUnresolvedOuterAttrs
    case inSubquery: InSubquery => inSubquery.query.getUnresolvedOuterAttrs
    case other if other.containsPattern(PLAN_EXPRESSION) =>
      other.collect {
        case nested: SubqueryExpression => nested.getUnresolvedOuterAttrs
      }.flatten
    case _ => Seq.empty
  }

  // Attributes contributed by child operators; children without plan expressions are skipped.
  val fromChildren = plan.children
    .filter(_.containsPattern(PLAN_EXPRESSION))
    .flatMap(child => getOuterAttrsNeedToBePropagated(child))

  fromOwnExpressions ++ fromChildren
}

/**
 * Selects the outer references of subquery `s` that the plan `p` cannot supply.
 *
 * Only [[AttributeReference]]s taken from `s.getOuterAttrs` are considered; one is kept
 * when `p.inputSet` does not contain it. Every other kind of expression is dropped.
 *
 * @param s the subquery expression whose outer attributes are inspected
 * @param p the plan whose input set is checked against
 * @return the outer attribute references that are absent from `p.inputSet`
 */
private def getUnresolvedOuterReferences(
    s: SubqueryExpression,
    p: LogicalPlan): Seq[Expression] = {
  val availableInputs = p.inputSet
  s.getOuterAttrs.collect {
    // An outer reference that is not part of the plan's input cannot be handled here.
    case attr: AttributeReference if !availableInputs.contains(attr) => attr
  }
}

/**
* Resolves the subquery plan that is referenced in a subquery expression, by invoking the
* entire analyzer recursively. We set outer plan in `AnalysisContext`, so that the analyzer
Expand All @@ -2274,15 +2305,23 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
e: SubqueryExpression,
outer: LogicalPlan)(
f: (LogicalPlan, Seq[Expression]) => SubqueryExpression): SubqueryExpression = {
val newSubqueryPlan = AnalysisContext.withOuterPlan(outer) {
executeSameContext(e.plan)
val newSubqueryPlan = if (AnalysisContext.get.outerPlan.isDefined) {
val propogatedOuterPlan = AnalysisContext.get.outerPlan.get
AnalysisContext.withOuterPlan(propogatedOuterPlan) {
executeSameContext(e.plan)
}
} else {
AnalysisContext.withOuterPlan(outer) {
executeSameContext(e.plan)
}
}

// If the subquery plan is fully resolved, pull the outer references and record
// them as children of SubqueryExpression.
if (newSubqueryPlan.resolved) {
// Record the outer references as children of subquery expression.
f(newSubqueryPlan, SubExprUtils.getOuterReferences(newSubqueryPlan))
f(newSubqueryPlan, SubExprUtils.getOuterReferences(newSubqueryPlan) ++
getOuterAttrsNeedToBePropagated(newSubqueryPlan))
} else {
e.withNewPlan(newSubqueryPlan)
}
Expand All @@ -2299,18 +2338,45 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
*/
private def resolveSubQueries(plan: LogicalPlan, outer: LogicalPlan): LogicalPlan = {
plan.transformAllExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION), ruleId) {
case s @ ScalarSubquery(sub, _, exprId, _, _, _, _) if !sub.resolved =>
resolveSubQuery(s, outer)(ScalarSubquery(_, _, exprId))
case e @ Exists(sub, _, exprId, _, _) if !sub.resolved =>
resolveSubQuery(e, outer)(Exists(_, _, exprId))
// There are four kinds of outer references here:
// 1. Outer references newly introduced in the subquery `res` that can be
//    resolved in the current `plan`.
//    They are extracted by `SubExprUtils.getOuterReferences(res.plan)` and
//    stored in res.outerAttrs.
// 2. Outer references newly introduced in the subquery `res` that cannot be
//    resolved in the current `plan`.
//    They are extracted by `SubExprUtils.getOuterReferences(res.plan)`, filtered
//    by `getUnresolvedOuterReferences(res, plan)`, and stored in
//    res.unresolvedOuterAttrs.
// 3. Outer references introduced by a subquery nested within `res.plan` that
//    can be resolved in the current `plan`.
//    They are extracted by `getOuterAttrsNeedToBePropagated(res.plan)`, filtered
//    by `plan.inputSet.contains(_)`, and need to be stored in res.outerAttrs.
// 4. Outer references introduced by a subquery nested within `res.plan` that
//    cannot be resolved in the current `plan`.
//    They are extracted by `getOuterAttrsNeedToBePropagated(res.plan)`, filtered
//    by `!plan.inputSet.contains(_)`, and need to be stored in both
//    res.outerAttrs and res.unresolvedOuterAttrs.
case s @ ScalarSubquery(sub, _, _, exprId, _, _, _, _) if !sub.resolved =>
val res = resolveSubQuery(s, outer)(ScalarSubquery(_, _, Seq.empty, exprId))
val unresolvedOuterReferences = getUnresolvedOuterReferences(res, plan)
res.withNewUnresolvedOuterAttrs(unresolvedOuterReferences)
case e @ Exists(sub, _, _, exprId, _, _) if !sub.resolved =>
val res = resolveSubQuery(e, outer)(Exists(_, _, Seq.empty, exprId))
val unresolvedOuterReferences = getUnresolvedOuterReferences(res, plan)
res.withNewUnresolvedOuterAttrs(unresolvedOuterReferences)
case InSubquery(values, l @ ListQuery(_, _, exprId, _, _, _))
if values.forall(_.resolved) && !l.resolved =>
val expr = resolveSubQuery(l, outer)((plan, exprs) => {
ListQuery(plan, exprs, exprId, plan.output.length)
})
InSubquery(values, expr.asInstanceOf[ListQuery])
case s @ LateralSubquery(sub, _, exprId, _, _) if !sub.resolved =>
resolveSubQuery(s, outer)(LateralSubquery(_, _, exprId))
ListQuery(plan, exprs, Seq.empty, exprId, plan.output.length)
}).asInstanceOf[ListQuery]
val unresolvedOuterReferences = getUnresolvedOuterReferences(expr, plan)
val newExpr = expr.withNewUnresolvedOuterAttrs(unresolvedOuterReferences)
InSubquery(values, newExpr)
case s @ LateralSubquery(sub, _, _, exprId, _, _) if !sub.resolved =>
val res = resolveSubQuery(s, outer)(LateralSubquery(_, _, Seq.empty, exprId))
val unresolvedOuterReferences = getUnresolvedOuterReferences(res, plan)
res.withNewUnresolvedOuterAttrs(unresolvedOuterReferences)
case a: FunctionTableSubqueryArgumentExpression if !a.plan.resolved =>
resolveSubQuery(a, outer)(
(plan, outerAttrs) => a.copy(plan = plan, outerAttrs = outerAttrs))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,32 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
}
}

/**
 * Fails analysis if any subquery expression reachable from `plan` still reports
 * unresolved outer attributes.
 *
 * The expressions of `plan` are inspected first, then every child operator that contains
 * the `PLAN_EXPRESSION` pattern is checked recursively. The first offending subquery
 * expression raises `UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.CORRELATED_COLUMN_NOT_FOUND`.
 *
 * @param plan the operator to validate
 */
def checkNoUnresolvedOuterReferencesInMainQuery(plan: LogicalPlan): Unit = {
  // Raises the analysis error at the position of the offending subquery expression.
  def correlatedColumnNotFound(subquery: SubqueryExpression): Nothing = {
    subquery.failAnalysis(
      errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.CORRELATED_COLUMN_NOT_FOUND",
      messageParameters = Map.empty)
  }

  plan.expressions.foreach {
    case subquery: SubqueryExpression if subquery.getUnresolvedOuterAttrs.nonEmpty =>
      correlatedColumnNotFound(subquery)
    case inSubquery: InSubquery if inSubquery.query.getUnresolvedOuterAttrs.nonEmpty =>
      correlatedColumnNotFound(inSubquery.query)
    case other if other.containsPattern(PLAN_EXPRESSION) =>
      // Look at every subquery expression inside this expression tree.
      other.foreach {
        case nested: SubqueryExpression if nested.getUnresolvedOuterAttrs.nonEmpty =>
          correlatedColumnNotFound(nested)
        case _ =>
      }
    case _ =>
  }

  // Only children that contain plan expressions can hold subqueries.
  plan.children
    .filter(_.containsPattern(PLAN_EXPRESSION))
    .foreach(child => checkNoUnresolvedOuterReferencesInMainQuery(child))
}

def checkAnalysis(plan: LogicalPlan): Unit = {
// We should inline all CTE relations to restore the original plan shape, as the analysis check
// may need to match certain plan shapes. For dangling CTE relations, they will still be kept
Expand All @@ -244,6 +270,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
}
preemptedError.clear()
try {
checkNoUnresolvedOuterReferencesInMainQuery(inlinedPlan)
checkAnalysis0(inlinedPlan)
preemptedError.getErrorOpt().foreach(throw _) // throw preempted error if any
} catch {
Expand Down Expand Up @@ -1137,14 +1164,28 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
case _ =>
}

/**
 * Rejects a subquery expression that reports unresolved outer attributes while
 * `SQLConf.SUPPORT_NESTED_CORRELATED_SUBQUERIES` is turned off.
 *
 * @param expr the subquery expression to validate
 */
def checkUnresolvedOuterReferences(expr: SubqueryExpression): Unit = {
  val nestedCorrelationDisabled =
    !SQLConf.get.getConf(SQLConf.SUPPORT_NESTED_CORRELATED_SUBQUERIES)
  // The attributes are only inspected when the feature flag is off.
  if (nestedCorrelationDisabled && expr.getUnresolvedOuterAttrs.nonEmpty) {
    expr.failAnalysis(
      errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." +
        "NESTED_CORRELATED_SUBQUERIES_NOT_SUPPORTED",
      messageParameters = Map.empty)
  }
}

// Check if there are nested correlated subqueries in the plan.
checkUnresolvedOuterReferences(expr)


// Validate the subquery plan.
checkAnalysis0(expr.plan)

// Check if there is outer attribute that cannot be found from the plan.
checkOuterReference(plan, expr)

expr match {
case ScalarSubquery(query, outerAttrs, _, _, _, _, _) =>
case ScalarSubquery(query, outerAttrs, _, _, _, _, _, _) =>
// Scalar subquery must return one column as output.
if (query.output.size != 1) {
throw QueryCompilationErrors.subqueryReturnMoreThanOneColumn(query.output.size,
Expand Down Expand Up @@ -1354,9 +1395,9 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
// +- Project [c1#87, c2#88]
// : (Aggregate or Window operator)
// : +- Filter [outer(c2#77) >= c2#88)]
// : +- SubqueryAlias t2, `t2`
// : +- Project [_1#84 AS c1#87, _2#85 AS c2#88]
// : +- LocalRelation [_1#84, _2#85]
// : +- SubqueryAlias t2, `t2`
// : +- Project [_1#84 AS c1#87, _2#85 AS c2#88]
// : +- LocalRelation [_1#84, _2#85]
// +- SubqueryAlias t1, `t1`
// +- Project [_1#73 AS c1#76, _2#74 AS c2#77]
// +- LocalRelation [_1#73, _2#74]
Expand All @@ -1373,7 +1414,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
// Original subquery plan:
// Aggregate [count(1)]
// +- Filter ((a + b) = outer(c))
// +- LocalRelation [a, b]
// +- LocalRelation [a, b]
//
// Plan after pulling up correlated predicates:
// Aggregate [a, b] [count(1), a, b]
Expand All @@ -1383,8 +1424,8 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
// Project [c1, count(1)]
// +- Join LeftOuter ((a + b) = c)
// :- LocalRelation [c]
// +- Aggregate [a, b] [count(1), a, b]
// +- LocalRelation [a, b]
// +- Aggregate [a, b] [count(1), a, b]
// +- LocalRelation [a, b]
//
// The right hand side of the join transformed from the subquery will output
// count(1) | a | b
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,14 @@ import org.apache.spark.sql.types.DataType
case class FunctionTableSubqueryArgumentExpression(
plan: LogicalPlan,
outerAttrs: Seq[Expression] = Seq.empty,
unresolvedOuterAttrs: Seq[Expression] = Seq.empty,
exprId: ExprId = NamedExpression.newExprId,
partitionByExpressions: Seq[Expression] = Seq.empty,
withSinglePartition: Boolean = false,
orderByExpressions: Seq[SortOrder] = Seq.empty,
selectedInputExpressions: Seq[PythonUDTFSelectedExpression] = Seq.empty)
extends SubqueryExpression(plan, outerAttrs, exprId, Seq.empty, None) with Unevaluable {
extends SubqueryExpression(
plan, outerAttrs, unresolvedOuterAttrs, exprId, Seq.empty, None) with Unevaluable {

assert(!(withSinglePartition && partitionByExpressions.nonEmpty),
"WITH SINGLE PARTITION is mutually exclusive with PARTITION BY")
Expand All @@ -83,6 +85,14 @@ case class FunctionTableSubqueryArgumentExpression(
copy(plan = plan)
override def withNewOuterAttrs(outerAttrs: Seq[Expression])
: FunctionTableSubqueryArgumentExpression = copy(outerAttrs = outerAttrs)
/**
 * Returns a copy of this expression with `unresolvedOuterAttrs` replaced.
 *
 * The new attributes are compared against `outerAttrs`, but a mismatch is currently
 * not reported: the branch below is a placeholder until an error class exists.
 *
 * @param unresolvedOuterAttrs the unresolved outer attributes to record on the copy
 * @return the copied expression
 */
override def withNewUnresolvedOuterAttrs(
    unresolvedOuterAttrs: Seq[Expression]): FunctionTableSubqueryArgumentExpression = {
  val hasAttrMissingFromOuterAttrs =
    unresolvedOuterAttrs.exists(attr => !outerAttrs.contains(attr))
  if (hasAttrMissingFromOuterAttrs) {
    // TODO(avery): create suitable error subclass to throw
  }
  copy(unresolvedOuterAttrs = unresolvedOuterAttrs)
}
override def hint: Option[HintInfo] = None
override def withNewHint(hint: Option[HintInfo]): FunctionTableSubqueryArgumentExpression =
copy()
Expand All @@ -91,6 +101,7 @@ case class FunctionTableSubqueryArgumentExpression(
FunctionTableSubqueryArgumentExpression(
plan.canonicalized,
outerAttrs.map(_.canonicalized),
unresolvedOuterAttrs.map(_.canonicalized),
ExprId(0),
partitionByExpressions,
withSinglePartition,
Expand Down
Loading
Loading