@@ -234,6 +234,11 @@ abstract class Optimizer(catalogManager: CatalogManager)
// this batch.
Batch("Early Filter and Projection Push-Down", Once, earlyScanPushDownRules: _*),
Batch("Update CTE Relation Stats", Once, UpdateCTERelationStats),
// This batch pushes Join through Union when the right side is broadcastable.
// It must run after "Early Filter and Projection Push-Down" because it relies on
// accurate stats (e.g., DSv2 relations only report stats after V2ScanRelationPushDown).
Batch("Push Down Join Through Union", Once,
PushDownJoinThroughUnion),
// Since join costs in AQP can change between multiple runs, there is no reason that we have an
// idempotence enforcement on this batch. We thus make it FixedPoint(1) instead of Once.
Batch("Join Reorder", FixedPoint(1),
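The ordering constraint described in the batch comment above can be sketched with a self-contained toy model. This is not Catalyst itself; the names `earlyPushDown` and `pushJoinThroughUnion` are illustrative stand-ins for V2ScanRelationPushDown and the new rule, showing why a stats-dependent rule must run after the rule that makes stats available:

```scala
// Toy model: a stats-dependent rule silently no-ops if it runs before
// the rule that populates statistics.
object BatchOrderingSketch {
  case class Plan(stats: Option[Long], pushedToUnion: Boolean = false)

  // Stand-in for V2ScanRelationPushDown: makes stats available.
  def earlyPushDown(p: Plan): Plan = p.copy(stats = Some(1000L))

  // Stand-in for PushDownJoinThroughUnion: only fires when stats exist
  // and the (hypothetical) size is under a broadcast threshold.
  def pushJoinThroughUnion(p: Plan, threshold: Long = 10L * 1024 * 1024): Plan =
    p.stats match {
      case Some(size) if size <= threshold => p.copy(pushedToUnion = true)
      case _ => p // no stats yet: conservatively skip
    }

  def main(args: Array[String]): Unit = {
    val raw = Plan(stats = None)
    // Wrong order: the join rule sees no stats and does nothing.
    assert(!pushJoinThroughUnion(raw).pushedToUnion)
    // Correct order: stats first, then the join push-down fires.
    assert(pushJoinThroughUnion(earlyPushDown(raw)).pushedToUnion)
    println("ordering-sketch ok")
  }
}
```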
@@ -0,0 +1,115 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.analysis.DeduplicateRelations
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.{JOIN, UNION}
import org.apache.spark.sql.internal.SQLConf

/**
* Pushes down `Join` through `Union` when the right side of the join is small enough
* to broadcast.
*
* This rule transforms the pattern:
* {{{
* Join(Union(c1, c2, ..., cN), right, joinType, cond)
* }}}
* into:
* {{{
* Union(Join(c1, right, joinType, cond1), Join(c2, right, joinType, cond2), ...)
* }}}
*
* where each `condK` has the Union output attributes rewritten to the corresponding child's
* output attributes.
*
* This is beneficial when the right side is small enough to broadcast, because it avoids
* shuffling the (potentially very large) Union result before the Join. Instead, each Union
* branch joins independently with the broadcasted right side.
*
* Applicable join types: Inner, LeftOuter.
*/
object PushDownJoinThroughUnion
extends Rule[LogicalPlan]
with JoinSelectionHelper {

def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
_.containsAllPatterns(JOIN, UNION), ruleId) {

case join @ Join(u: Union, right, joinType, joinCond, hint)
if conf.getConf(SQLConf.PUSH_DOWN_JOIN_THROUGH_UNION_ENABLED) &&
(joinType == Inner || joinType == LeftOuter) &&
canPlanAsBroadcastHashJoin(join, conf) &&
// Conservatively exclude any right subtree containing subqueries,
// as DeduplicateRelations may not correctly handle correlated references.
// Non-correlated subqueries are safe in theory but excluded for simplicity.
!right.exists(_.expressions.exists(SubqueryExpression.hasSubquery)) =>

// Each Union branch gets its own independent copy of `right` with fresh
// ExprIds to avoid duplicate ExprIds in the plan tree. The first branch
// reuses the original `right` directly; subsequent branches use the
// "fake self-join + DeduplicateRelations" pattern (same as InlineCTE)
// to clone the subtree.
//
// Note: join condition attributes referencing `right.output` are assumed
// to share the same ExprIds, which holds after the analysis phase.
val unionHeadOutput = u.children.head.output
val newChildren = u.children.zipWithIndex.map { case (child, idx) =>
val newRight = if (idx == 0) right else dedupRight(right)
// For idx == 0, child == u.children.head, so leftRewrites is identity
// and rightRewrites is empty; the condition is used as-is.
val leftRewrites = AttributeMap(unionHeadOutput.zip(child.output))
val rightRewrites = if (idx == 0) {
AttributeMap.empty[Attribute]
} else {
AttributeMap(right.output.zip(newRight.output))
}
val newCond = joinCond.map(_.transform {
case a: Attribute if leftRewrites.contains(a) => leftRewrites(a)
case a: Attribute if rightRewrites.contains(a) => rightRewrites(a)
})
Join(child, newRight, joinType, newCond, hint)
}
u.withNewChildren(newChildren)
}

/**
* Creates a copy of `plan` with fresh ExprIds on all output attributes.
* Uses the same "fake self-join + DeduplicateRelations" pattern as InlineCTE.
*
* This works for any plan whose leaf nodes implement `MultiInstanceRelation`
* (e.g., `LocalRelation`, `LogicalRelation`, `HiveTableRelation`), which covers
* both test and production scenarios. If a leaf node does not implement
* `MultiInstanceRelation` (e.g., some custom data sources), `DeduplicateRelations`
* will not refresh its ExprIds. Such cases are rare in practice.
*/
private def dedupRight(plan: LogicalPlan): LogicalPlan = {
DeduplicateRelations(
Join(plan, plan, Inner, None, JoinHint.NONE)
) match {
case Join(_, deduped, _, _, _) => deduped
case other =>
throw SparkException.internalError(
  s"Unexpected plan shape after DeduplicateRelations: ${other.getClass.getName}")
}
}
}

Member:
Do any other optimizations throw bug-like internal errors like this?

Contributor Author:
Thanks @yaooqinn. Yes, SparkException.internalError is used in several optimizer rules as a defensive guard for "should-never-happen" plan shapes, for example:

  • NestedColumnAliasing: "Unreasonable plan after optimization: $other"
  • PushExtraPredicateThroughJoin / Optimizer: "Unexpected join type: $other"
  • DecorrelateInnerQuery: "Unexpected domain join type $o"
  • subquery.scala: "Unexpected plan when optimizing one row relation subquery: $o"

The dedupRight method here follows the same pattern: it guards against the (theoretically impossible) case where DeduplicateRelations changes the Join plan shape.

That said, InlineCTE uses the same "fake self-join + DeduplicateRelations" approach and simply calls .children(1) directly without any defensive check. I can align with InlineCTE and remove the explicit throw if you think that's cleaner. Alternatively, I could keep the pattern match but return the original plan unchanged in the fallback case (skipping the dedup rather than failing). Which approach would you prefer?
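The rewrite this rule performs, Join(Union(c1..cN), right) into Union(Join(c1, right)..Join(cN, right)) with the right side cloned per branch under fresh ids, can be sketched with a self-contained toy model. The `Plan` classes and integer ids below are illustrative stand-ins, not the Catalyst API:

```scala
// Toy model of pushing a join through a union: the first branch reuses the
// original right side, later branches get clones with fresh ids (the analogue
// of the fresh-ExprId requirement).
object PushDownSketch {
  private var nextId = 0
  private def freshId(): Int = { nextId += 1; nextId }

  sealed trait Plan
  case class Relation(name: String, id: Int) extends Plan
  case class Union(children: Seq[Plan]) extends Plan
  case class Join(left: Plan, right: Plan) extends Plan

  // Clone a subtree, assigning fresh ids to every relation in it.
  private def cloneRight(p: Plan): Plan = p match {
    case Relation(n, _) => Relation(n, freshId())
    case Union(cs)      => Union(cs.map(cloneRight))
    case Join(l, r)     => Join(cloneRight(l), cloneRight(r))
  }

  def pushDown(plan: Plan): Plan = plan match {
    case Join(Union(children), right) =>
      Union(children.zipWithIndex.map {
        case (c, 0) => Join(c, right)             // first branch: reuse original
        case (c, _) => Join(c, cloneRight(right)) // later branches: fresh clone
      })
    case other => other
  }

  def main(args: Array[String]): Unit = {
    val dim  = Relation("dim", freshId())
    val plan = Join(Union(Seq(Relation("t1", freshId()), Relation("t2", freshId()))), dim)
    pushDown(plan) match {
      case Union(Seq(Join(_, r1: Relation), Join(_, r2: Relation))) =>
        assert(r1.name == "dim" && r2.name == "dim")
        assert(r1.id != r2.id) // clones do not share ids across branches
        println("pushdown-sketch ok")
      case other => sys.error(s"unexpected: $other")
    }
  }
}
```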
@@ -157,6 +157,7 @@ object RuleIdCollection {
"org.apache.spark.sql.catalyst.optimizer.Optimizer$OptimizeSubqueries" ::
"org.apache.spark.sql.catalyst.optimizer.PropagateEmptyRelation" ::
"org.apache.spark.sql.catalyst.optimizer.PruneFilters" ::
"org.apache.spark.sql.catalyst.optimizer.PushDownJoinThroughUnion" ::
"org.apache.spark.sql.catalyst.optimizer.PushDownLeftSemiAntiJoin" ::
"org.apache.spark.sql.catalyst.optimizer.PushExtraPredicateThroughJoin" ::
"org.apache.spark.sql.catalyst.optimizer.PushFoldableIntoBranches" ::
@@ -639,6 +639,15 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val PUSH_DOWN_JOIN_THROUGH_UNION_ENABLED =
buildConf("spark.sql.optimizer.pushDownJoinThroughUnion.enabled")
.doc("When true, pushes down Join through Union when the join's right side " +
"is small enough to broadcast, avoiding shuffling the Union result.")
.version("4.2.0")
.withBindingPolicy(ConfigBindingPolicy.SESSION)
.booleanConf
.createWithDefault(true)

Contributor Author:
If it needs to be set to false by default, please let me know.

Member:
+1 for true by default, because this configuration is only a safeguard against any future regression.

Member:
According to the code, we can use spark.sql.optimizer.excludedRules instead of this, right? Is there any difference?

Contributor Author:
Good point @dongjoon-hyun. You're right: spark.sql.optimizer.excludedRules already provides a general mechanism to disable any optimizer rule, and adding a dedicated config for each rule would lead to config proliferation. I'll remove the dedicated config spark.sql.optimizer.pushDownJoinThroughUnion.enabled and rely on excludedRules instead. Thanks for the suggestion!
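The excludedRules alternative mentioned in the review discussion would look roughly like the fragment below. This is a hedged sketch, not part of the PR: it assumes an active SparkSession named `spark`, and uses the existing general-purpose `spark.sql.optimizer.excludedRules` configuration rather than a dedicated flag:

```scala
// Hypothetical usage: disable the new rule for one session via the general
// exclusion mechanism instead of a rule-specific config.
spark.conf.set(
  "spark.sql.optimizer.excludedRules",
  "org.apache.spark.sql.catalyst.optimizer.PushDownJoinThroughUnion")
```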

val DYNAMIC_PARTITION_PRUNING_USE_STATS =
buildConf("spark.sql.optimizer.dynamicPartitionPruning.useStats")
.internal()