Skip to content
5 changes: 5 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -5906,6 +5906,11 @@
"Correlated scalar subqueries must be aggregated to return at most one row."
]
},
"NESTED_CORRELATED_SUBQUERIES_NOT_SUPPORTED" : {
"message" : [
"Nested correlated subqueries are not supported."
]
},
"NON_CORRELATED_COLUMNS_IN_GROUP_BY" : {
"message" : [
"A GROUP BY clause in a scalar correlated subquery cannot contain non-correlated columns: <value>."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2263,6 +2263,37 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
* Note: CTEs are handled in CTESubstitution.
*/
object ResolveSubquery extends Rule[LogicalPlan] {

/**
 * Gathers the unresolved outer attributes recorded on every subquery expression hosted by
 * `plan` or by any of its descendant operators.
 *
 * For each expression of the current operator:
 *  - a top-level [[SubqueryExpression]] contributes its own unresolved outer attributes;
 *  - a top-level [[InSubquery]] contributes those of its `query`;
 *  - any other expression containing a plan expression is searched for nested
 *    [[SubqueryExpression]]s, each contributing its unresolved outer attributes.
 *
 * Child operators are visited recursively, skipping subtrees that do not contain the
 * `PLAN_EXPRESSION` pattern.
 *
 * @param plan the operator tree to inspect
 * @return attributes found on this operator's expressions, followed by those found in children
 */
private def getOuterAttrsNeedToBePropagated(plan: LogicalPlan): Seq[Expression] = {
  val fromOwnExpressions: Seq[Expression] = plan.expressions.flatMap { expression =>
    expression match {
      case subquery: SubqueryExpression =>
        subquery.getUnresolvedOuterAttrs
      case inSubquery: InSubquery =>
        inSubquery.query.getUnresolvedOuterAttrs
      case other if other.containsPattern(PLAN_EXPRESSION) =>
        other.collect {
          case nested: SubqueryExpression => nested.getUnresolvedOuterAttrs
        }.flatten
      case _ =>
        Seq.empty
    }
  }

  // Only descend into children that can actually host a subquery expression.
  val fromChildren: Seq[Expression] = plan.children
    .filter(_.containsPattern(PLAN_EXPRESSION))
    .flatMap(child => getOuterAttrsNeedToBePropagated(child))

  fromOwnExpressions ++ fromChildren
}

/**
 * Returns the outer references of subquery `s` that the operator `p` cannot supply.
 *
 * Only [[AttributeReference]]s are considered: one is kept when it is absent from
 * `p.inputSet`. Every other kind of expression returned by `s.getOuterAttrs` is dropped.
 *
 * @param s the subquery expression whose outer attributes are examined
 * @param p the operator hosting `s`
 * @return the attribute references that are not part of `p`'s input
 */
private def getUnresolvedOuterReferences(
    s: SubqueryExpression,
    p: LogicalPlan): Seq[Expression] = {
  s.getOuterAttrs.collect {
    // An attribute missing from the host's input cannot be handled in the current plan.
    case attr: AttributeReference if !p.inputSet.contains(attr) => attr
  }
}

/**
* Resolves the subquery plan that is referenced in a subquery expression, by invoking the
* entire analyzer recursively. We set outer plan in `AnalysisContext`, so that the analyzer
Expand All @@ -2274,15 +2305,23 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
e: SubqueryExpression,
outer: LogicalPlan)(
f: (LogicalPlan, Seq[Expression]) => SubqueryExpression): SubqueryExpression = {
val newSubqueryPlan = AnalysisContext.withOuterPlan(outer) {
executeSameContext(e.plan)
val newSubqueryPlan = if (AnalysisContext.get.outerPlan.isDefined) {
val propogatedOuterPlan = AnalysisContext.get.outerPlan.get
AnalysisContext.withOuterPlan(propogatedOuterPlan) {
executeSameContext(e.plan)
}
} else {
AnalysisContext.withOuterPlan(outer) {
executeSameContext(e.plan)
}
}

// If the subquery plan is fully resolved, pull the outer references and record
// them as children of SubqueryExpression.
if (newSubqueryPlan.resolved) {
// Record the outer references as children of subquery expression.
f(newSubqueryPlan, SubExprUtils.getOuterReferences(newSubqueryPlan))
f(newSubqueryPlan, SubExprUtils.getOuterReferences(newSubqueryPlan) ++
getOuterAttrsNeedToBePropagated(newSubqueryPlan))
} else {
e.withNewPlan(newSubqueryPlan)
}
Expand All @@ -2299,18 +2338,45 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
*/
private def resolveSubQueries(plan: LogicalPlan, outer: LogicalPlan): LogicalPlan = {
plan.transformAllExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION), ruleId) {
case s @ ScalarSubquery(sub, _, exprId, _, _, _, _) if !sub.resolved =>
resolveSubQuery(s, outer)(ScalarSubquery(_, _, exprId))
case e @ Exists(sub, _, exprId, _, _) if !sub.resolved =>
resolveSubQuery(e, outer)(Exists(_, _, exprId))
case InSubquery(values, l @ ListQuery(_, _, exprId, _, _, _))
// There are four kinds of outer references here:
// 1. Outer references which are newly introduced in the subquery `res`
// which can be resolved in current `plan`.
// It is extracted by `SubExprUtils.getOuterReferences(res.plan)` and
// stored among res.outerAttrs
// 2. Outer references which are newly introduced in the subquery `res`
// which cannot be resolved in current `plan`
// It is extracted by `SubExprUtils.getOuterReferences(res.plan)` with
// `getUnresolvedOuterReferences(res, plan)` filter and stored in
// res.unresolvedOuterAttrs
// 3. Outer references which are introduced by nested subquery within `res.plan`
// which can be resolved in current `plan`
// It is extracted by `getOuterAttrsNeedToBePropagated(res.plan)`, filtered
// by `plan.inputSet.contains(_)`, need to be stored in res.outerAttrs
// 4. Outer references which are introduced by nested subquery within `res.plan`
// which cannot be resolved in current `plan`
// It is extracted by `getOuterAttrsNeedToBePropagated(res.plan)`, filtered
// by `!plan.inputSet.contains(_)`, need to be stored in
// res.outerAttrs and res.unresolvedOuterAttrs
case s @ ScalarSubquery(sub, _, _, exprId, _, _, _, _) if !sub.resolved =>
val res = resolveSubQuery(s, outer)(ScalarSubquery(_, _, Seq.empty, exprId))
val unresolvedOuterReferences = getUnresolvedOuterReferences(res, plan)
res.withNewUnresolvedOuterAttrs(unresolvedOuterReferences)
case e @ Exists(sub, _, _, exprId, _, _) if !sub.resolved =>
val res = resolveSubQuery(e, outer)(Exists(_, _, Seq.empty, exprId))
val unresolvedOuterReferences = getUnresolvedOuterReferences(res, plan)
res.withNewUnresolvedOuterAttrs(unresolvedOuterReferences)
case InSubquery(values, l @ ListQuery(_, _, _, exprId, _, _, _))
if values.forall(_.resolved) && !l.resolved =>
val expr = resolveSubQuery(l, outer)((plan, exprs) => {
ListQuery(plan, exprs, exprId, plan.output.length)
})
InSubquery(values, expr.asInstanceOf[ListQuery])
case s @ LateralSubquery(sub, _, exprId, _, _) if !sub.resolved =>
resolveSubQuery(s, outer)(LateralSubquery(_, _, exprId))
ListQuery(plan, exprs, Seq.empty, exprId, plan.output.length)
}).asInstanceOf[ListQuery]
val unresolvedOuterReferences = getUnresolvedOuterReferences(expr, plan)
val newExpr = expr.withNewUnresolvedOuterAttrs(unresolvedOuterReferences)
InSubquery(values, newExpr)
case s @ LateralSubquery(sub, _, _, exprId, _, _) if !sub.resolved =>
val res = resolveSubQuery(s, outer)(LateralSubquery(_, _, Seq.empty, exprId))
val unresolvedOuterReferences = getUnresolvedOuterReferences(res, plan)
res.withNewUnresolvedOuterAttrs(unresolvedOuterReferences)
case a: FunctionTableSubqueryArgumentExpression if !a.plan.resolved =>
resolveSubQuery(a, outer)(
(plan, outerAttrs) => a.copy(plan = plan, outerAttrs = outerAttrs))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,32 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
}
}

/**
 * Fails analysis with `UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.CORRELATED_COLUMN_NOT_FOUND`
 * as soon as a subquery expression hosted by `plan`, or by any of its descendant operators,
 * still carries unresolved outer attributes.
 *
 * The traversal follows `plan.children` and the expression trees of each operator. Child
 * subtrees without the `PLAN_EXPRESSION` pattern are skipped.
 *
 * @param plan the operator tree to validate
 */
def checkNoUnresolvedOuterReferencesInMainQuery(plan: LogicalPlan): Unit = {
  // Single place that raises the error, anchored at the offending subquery expression.
  def reject(subquery: SubqueryExpression): Unit = {
    subquery.failAnalysis(
      errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.CORRELATED_COLUMN_NOT_FOUND",
      messageParameters = Map.empty)
  }

  def hasUnresolved(subquery: SubqueryExpression): Boolean =
    subquery.getUnresolvedOuterAttrs.nonEmpty

  plan.expressions.foreach { expression =>
    expression match {
      case subquery: SubqueryExpression if hasUnresolved(subquery) =>
        reject(subquery)
      case inSubquery: InSubquery if hasUnresolved(inSubquery.query) =>
        reject(inSubquery.query)
      case other if other.containsPattern(PLAN_EXPRESSION) =>
        // Look for offending subquery expressions nested anywhere inside this expression.
        other.foreach {
          case nested: SubqueryExpression if hasUnresolved(nested) => reject(nested)
          case _ =>
        }
      case _ =>
    }
  }

  plan.children
    .filter(_.containsPattern(PLAN_EXPRESSION))
    .foreach(child => checkNoUnresolvedOuterReferencesInMainQuery(child))
}

def checkAnalysis(plan: LogicalPlan): Unit = {
// We should inline all CTE relations to restore the original plan shape, as the analysis check
// may need to match certain plan shapes. For dangling CTE relations, they will still be kept
Expand All @@ -244,6 +270,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
}
preemptedError.clear()
try {
checkNoUnresolvedOuterReferencesInMainQuery(inlinedPlan)
checkAnalysis0(inlinedPlan)
preemptedError.getErrorOpt().foreach(throw _) // throw preempted error if any
} catch {
Expand Down Expand Up @@ -1137,14 +1164,28 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
case _ =>
}

// Rejects a subquery expression that still has unresolved outer attributes while the
// `SQLConf.SUPPORT_NESTED_CORRELATED_SUBQUERIES` flag is off. Does nothing when the flag is
// on or when the expression has no unresolved outer attributes.
def checkUnresolvedOuterReferences(expr: SubqueryExpression): Unit = {
  val nestedCorrelationEnabled =
    SQLConf.get.getConf(SQLConf.SUPPORT_NESTED_CORRELATED_SUBQUERIES)
  val rejected = !nestedCorrelationEnabled && expr.getUnresolvedOuterAttrs.nonEmpty
  if (rejected) {
    expr.failAnalysis(
      errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." +
        "NESTED_CORRELATED_SUBQUERIES_NOT_SUPPORTED",
      messageParameters = Map.empty)
  }
}

// Check if there are nested correlated subqueries in the plan.
checkUnresolvedOuterReferences(expr)


// Validate the subquery plan.
checkAnalysis0(expr.plan)

// Check if there is outer attribute that cannot be found from the plan.
checkOuterReference(plan, expr)

expr match {
case ScalarSubquery(query, outerAttrs, _, _, _, _, _) =>
case ScalarSubquery(query, outerAttrs, _, _, _, _, _, _) =>
// Scalar subquery must return one column as output.
if (query.output.size != 1) {
throw QueryCompilationErrors.subqueryReturnMoreThanOneColumn(query.output.size,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,38 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
val outerPlan = AnalysisContext.get.outerPlan
if (outerPlan.isEmpty) return e

def resolve(nameParts: Seq[String]): Option[Expression] = try {
// Collects the plans of all subquery expressions reachable from `p`: those hosted by `p`'s
// own expressions (including subqueries nested inside those subquery plans) and those hosted
// by `p`'s child operators. Returns an empty sequence when `p` does not contain the
// `PLAN_EXPRESSION` pattern, so such subtrees are never walked.
// NOTE(review): if `collect` also visits the `ListQuery` child of an `InSubquery`, that
// query is matched by both cases below and its plan appears twice - confirm whether this
// duplication is intended.
def findNestedSubqueryPlans(p: LogicalPlan): Seq[LogicalPlan] = {
  if (p.containsPattern(PLAN_EXPRESSION)) {
    val hostedSubqueries: Seq[SubqueryExpression] = p.expressions.flatMap { expression =>
      expression.collect {
        case inSubquery: InSubquery => inSubquery.query
        case subquery: SubqueryExpression => subquery
      }
    }

    // For each hosted subquery, the plans nested inside it come before its own plan.
    val plansFromExpressions: Seq[LogicalPlan] = hostedSubqueries.flatMap { subquery =>
      findNestedSubqueryPlans(subquery.plan) :+ subquery.plan
    }

    val plansFromChildren: Seq[LogicalPlan] =
      p.children.flatMap(child => findNestedSubqueryPlans(child))

    // Plans found deeper in the tree are listed first: they sit closer to the outer
    // reference and are therefore more likely to expose the original attributes.
    // Since there are no conflicts, the order does not affect correctness.
    plansFromChildren ++ plansFromExpressions
  } else {
    // No nested subquery plans below this operator, so stop searching here.
    Seq.empty
  }
}

// The `outerPlan` passed in is the outermost plan. An outer reference may come from
// `outerPlan` itself or from any subquery plan nested inside it, so we try to resolve
// outer references against all of these plans.
val outerPlans = Seq(outerPlan.get) ++ findNestedSubqueryPlans(outerPlan.get)

def resolve(nameParts: Seq[String], outerPlan: Option[LogicalPlan]): Option[Expression] = try {
outerPlan.get match {
// Subqueries in UnresolvedHaving can host grouping expressions and aggregate functions.
// We should resolve columns with `agg.output` and the rule `ResolveAggregateFunctions` will
Expand All @@ -240,14 +271,21 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
None
}

e.transformWithPruning(
_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, TEMP_RESOLVED_COLUMN)) {
e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, TEMP_RESOLVED_COLUMN)) {
case u: UnresolvedAttribute =>
resolve(u.nameParts).getOrElse(u)
val maybeResolved = outerPlans.foldLeft(Option.empty[Expression]) { (acc, plan) =>
// If we've already resolved, keep that; otherwise try this plan
acc.orElse(resolve(u.nameParts, Some(plan)))
}
maybeResolved.getOrElse(u)
// Re-resolves `TempResolvedColumn` as outer references if it has tried to be resolved with
// Aggregate but failed.
case t: TempResolvedColumn if t.hasTried =>
resolve(t.nameParts).getOrElse(t)
val maybeResolved = outerPlans.foldLeft(Option.empty[Expression]) { (acc, plan) =>
// If we've already resolved, keep that; otherwise try this plan
acc.orElse(resolve(t.nameParts, Some(plan)))
}
maybeResolved.getOrElse(t)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ case class SQLFunction(
case (None, Some(Project(expr :: Nil, _: OneRowRelation)))
if !isTableFunc =>
(Some(expr), None)
case (Some(ScalarSubquery(Project(expr :: Nil, _: OneRowRelation), _, _, _, _, _, _)), None)
case (Some(ScalarSubquery(
Project(expr :: Nil, _: OneRowRelation), _, _, _, _, _, _, _)), None)
if !isTableFunc =>
(Some(expr), None)
case (_, _) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ case class DynamicPruningSubquery(
onlyInBroadcast: Boolean,
exprId: ExprId = NamedExpression.newExprId,
hint: Option[HintInfo] = None)
extends SubqueryExpression(buildQuery, Seq(pruningKey), exprId, Seq.empty, hint)
extends SubqueryExpression(buildQuery, Seq(pruningKey), Seq.empty, exprId, Seq.empty, hint)
with DynamicPruning
with Unevaluable
with UnaryLike[Expression] {
Expand All @@ -67,6 +67,14 @@ case class DynamicPruningSubquery(
copy()
}

/**
 * No-op for [[DynamicPruningSubquery]]: the supplied attributes are ignored and this same
 * instance is returned.
 *
 * @param unresolvedOuterAttrs ignored
 * @return `this`, unchanged
 */
override def withNewUnresolvedOuterAttrs(
    unresolvedOuterAttrs: Seq[Expression]): DynamicPruningSubquery = {
  // TODO(avery): create suitable error subclass to throw
  // DynamicPruningSubquery should not have this method called on it.
  this
}

// Returns a copy of this expression whose `hint` field is replaced by the given value.
override def withNewHint(hint: Option[HintInfo]): SubqueryExpression = copy(hint = hint)

override lazy val resolved: Boolean = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,14 @@ import org.apache.spark.sql.types.DataType
case class FunctionTableSubqueryArgumentExpression(
plan: LogicalPlan,
outerAttrs: Seq[Expression] = Seq.empty,
unresolvedOuterAttrs: Seq[Expression] = Seq.empty,
exprId: ExprId = NamedExpression.newExprId,
partitionByExpressions: Seq[Expression] = Seq.empty,
withSinglePartition: Boolean = false,
orderByExpressions: Seq[SortOrder] = Seq.empty,
selectedInputExpressions: Seq[PythonUDTFSelectedExpression] = Seq.empty)
extends SubqueryExpression(plan, outerAttrs, exprId, Seq.empty, None) with Unevaluable {
extends SubqueryExpression(
plan, outerAttrs, unresolvedOuterAttrs, exprId, Seq.empty, None) with Unevaluable {

assert(!(withSinglePartition && partitionByExpressions.nonEmpty),
"WITH SINGLE PARTITION is mutually exclusive with PARTITION BY")
Expand All @@ -83,6 +85,14 @@ case class FunctionTableSubqueryArgumentExpression(
copy(plan = plan)
// Returns a copy of this expression whose `outerAttrs` field is replaced by the given value.
override def withNewOuterAttrs(outerAttrs: Seq[Expression])
: FunctionTableSubqueryArgumentExpression = copy(outerAttrs = outerAttrs)
/**
 * Returns a copy of this expression with `unresolvedOuterAttrs` replaced.
 *
 * The new attributes are compared against `outerAttrs`, but a mismatch is currently not
 * reported (see the TODO below).
 *
 * @param unresolvedOuterAttrs the unresolved outer attributes to record
 * @return a copy carrying the given unresolved outer attributes
 */
override def withNewUnresolvedOuterAttrs(
    unresolvedOuterAttrs: Seq[Expression]): FunctionTableSubqueryArgumentExpression = {
  val hasAttrOutsideOuterAttrs =
    unresolvedOuterAttrs.exists(attr => !outerAttrs.contains(attr))
  if (hasAttrOutsideOuterAttrs) {
    // TODO(avery): create suitable error subclass to throw
  }
  copy(unresolvedOuterAttrs = unresolvedOuterAttrs)
}
// This expression never carries a hint: the accessor always yields `None`.
override def hint: Option[HintInfo] = None
// Returns a plain copy of this expression. `copy()` is called without arguments, so the
// supplied `hint` is not stored anywhere.
override def withNewHint(hint: Option[HintInfo]): FunctionTableSubqueryArgumentExpression =
copy()
Expand All @@ -91,6 +101,7 @@ case class FunctionTableSubqueryArgumentExpression(
FunctionTableSubqueryArgumentExpression(
plan.canonicalized,
outerAttrs.map(_.canonicalized),
unresolvedOuterAttrs.map(_.canonicalized),
ExprId(0),
partitionByExpressions,
withSinglePartition,
Expand Down
Loading