Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -379,36 +379,41 @@ case class KeyGroupedPartitioning(
expressions: Seq[Expression],
numPartitions: Int,
partitionValues: Seq[InternalRow] = Seq.empty,
originalPartitionValues: Seq[InternalRow] = Seq.empty) extends HashPartitioningLike {
originalPartitionValues: Seq[InternalRow] = Seq.empty,
isPartiallyClustered: Boolean = false) extends HashPartitioningLike {

// See SPARK-55848. We must check ClusteredDistribution BEFORE delegating to
// super.satisfies0(), because HashPartitioningLike.satisfies0() also matches
// ClusteredDistribution and returns true, which would short-circuit the
// isPartiallyClustered guard.
override def satisfies0(required: Distribution): Boolean = {
super.satisfies0(required) || {
required match {
case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) =>
if (requireAllClusterKeys) {
// Checks whether this partitioning is partitioned on exactly same clustering keys of
Copy link
Contributor

@peter-toth peter-toth Mar 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just noticed that you deleted three comments, but this change should only be a reordering of conditions. Could you please put those comments back?
Also, it seems those comments were deleted in #54751 as well; could you please restore them on branch-4.1 in a follow-up PR?

// `ClusteredDistribution`.
c.areAllClusterKeysMatched(expressions)
required match {
case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) =>
if (isPartiallyClustered) {
false
} else if (requireAllClusterKeys) {
// Checks whether this partitioning is partitioned on exactly same clustering keys of
// `ClusteredDistribution`.
c.areAllClusterKeysMatched(expressions)
} else {
// We'll need to find leaf attributes from the partition expressions first.
val attributes = expressions.flatMap(_.collectLeaves())

if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
// check that join keys (required clustering keys)
// overlap with partition keys (KeyGroupedPartitioning attributes)
requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) &&
expressions.forall(_.collectLeaves().size == 1)
} else {
// We'll need to find leaf attributes from the partition expressions first.
val attributes = expressions.flatMap(_.collectLeaves())

if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
// check that join keys (required clustering keys)
// overlap with partition keys (KeyGroupedPartitioning attributes)
requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) &&
expressions.forall(_.collectLeaves().size == 1)
} else {
attributes.forall(x => requiredClustering.exists(_.semanticEquals(x)))
}
attributes.forall(x => requiredClustering.exists(_.semanticEquals(x)))
}
}

case o @ OrderedDistribution(_) if SQLConf.get.v2BucketingAllowSorting =>
o.areAllClusterKeysMatched(expressions)
case o @ OrderedDistribution(_) if SQLConf.get.v2BucketingAllowSorting =>
o.areAllClusterKeysMatched(expressions)

case _ =>
false
}
case _ =>
super.satisfies0(required)
}
}

Expand All @@ -420,7 +425,7 @@ case class KeyGroupedPartitioning(
// the returned shuffle spec.
val joinKeyPositions = result.keyPositions.map(_.nonEmpty).zipWithIndex.filter(_._1).map(_._2)
val projectedPartitioning = KeyGroupedPartitioning(expressions, joinKeyPositions,
partitionValues, originalPartitionValues)
partitionValues, originalPartitionValues, isPartiallyClustered)
result.copy(partitioning = projectedPartitioning, joinKeyPositions = Some(joinKeyPositions))
} else {
result
Expand All @@ -435,15 +440,16 @@ case class KeyGroupedPartitioning(
}

override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
copy(expressions = newChildren)
copy(expressions = newChildren, isPartiallyClustered = isPartiallyClustered)
}

object KeyGroupedPartitioning {
def apply(
expressions: Seq[Expression],
projectionPositions: Seq[Int],
partitionValues: Seq[InternalRow],
originalPartitionValues: Seq[InternalRow]): KeyGroupedPartitioning = {
originalPartitionValues: Seq[InternalRow],
isPartiallyClustered: Boolean): KeyGroupedPartitioning = {
val projectedExpressions = projectionPositions.map(expressions(_))
val projectedPartitionValues = partitionValues.map(project(expressions, projectionPositions, _))
val projectedOriginalPartitionValues =
Expand All @@ -455,7 +461,7 @@ object KeyGroupedPartitioning {
.map(_.row)

KeyGroupedPartitioning(projectedExpressions, finalPartitionValues.length,
finalPartitionValues, projectedOriginalPartitionValues)
finalPartitionValues, projectedOriginalPartitionValues, isPartiallyClustered)
}

def project(
Expand Down Expand Up @@ -823,7 +829,10 @@ case class KeyGroupedShuffleSpec(
// transform functions.
// 4. the partition values from both sides are following the same order.
case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution, _) =>
distribution.clustering.length == otherDistribution.clustering.length &&
// SPARK-55848: partially-clustered partitioning is not compatible for SPJ
!partitioning.isPartiallyClustered &&
!otherPartitioning.isPartiallyClustered &&
distribution.clustering.length == otherDistribution.clustering.length &&
numPartitions == other.numPartitions && areKeysCompatible(otherSpec) &&
partitioning.partitionValues.zip(otherPartitioning.partitionValues).forall {
case (left, right) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ case class BatchScanExec(
}
}
k.copy(expressions = projectedExpressions, numPartitions = newPartValues.length,
partitionValues = newPartValues)
partitionValues = newPartValues,
isPartiallyClustered = spjParams.applyPartialClustering)
case p => p
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ case class EnsureRequirements(

private def ensureOrdering(plan: SparkPlan, distribution: Distribution) = {
(plan.outputPartitioning, distribution) match {
case (p @ KeyGroupedPartitioning(expressions, _, partitionValues, _),
case (p @ KeyGroupedPartitioning(expressions, _, partitionValues, _, _),
d @ OrderedDistribution(ordering)) if p.satisfies(d) =>
val attrs = expressions.flatMap(_.collectLeaves()).map(_.asInstanceOf[Attribute])
val partitionOrdering: Ordering[InternalRow] = {
Expand Down Expand Up @@ -325,12 +325,12 @@ case class EnsureRequirements(
reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys)
.orElse(reorderJoinKeysRecursively(
leftKeys, rightKeys, leftPartitioning, None))
case (Some(KeyGroupedPartitioning(clustering, _, _, _)), _) =>
case (Some(KeyGroupedPartitioning(clustering, _, _, _, _)), _) =>
val leafExprs = clustering.flatMap(_.collectLeaves())
reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, leftKeys)
.orElse(reorderJoinKeysRecursively(
leftKeys, rightKeys, None, rightPartitioning))
case (_, Some(KeyGroupedPartitioning(clustering, _, _, _))) =>
case (_, Some(KeyGroupedPartitioning(clustering, _, _, _, _))) =>
val leafExprs = clustering.flatMap(_.collectLeaves())
reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, rightKeys)
.orElse(reorderJoinKeysRecursively(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ object ShuffleExchangeExec {
ascending = true,
samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition)
case SinglePartition => new ConstantPartitioner
case k @ KeyGroupedPartitioning(expressions, n, _, _) =>
case k @ KeyGroupedPartitioning(expressions, n, _, _, _) =>
val valueMap = k.uniquePartitionValues.zipWithIndex.map {
case (partition, index) => (partition.toSeq(expressions.map(_.dataType)), index)
}.toMap
Expand Down Expand Up @@ -397,7 +397,7 @@ object ShuffleExchangeExec {
val projection = UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes)
row => projection(row)
case SinglePartition => identity
case KeyGroupedPartitioning(expressions, _, _, _) =>
case KeyGroupedPartitioning(expressions, _, _, _, _) =>
row => bindReferences(expressions, outputAttributes).map(_.eval(row))
case _ => throw SparkException.internalError(s"Exchange not implemented for $newPartitioning")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,10 @@ abstract class DistributionAndOrderingSuiteBase
plan: QueryPlan[T]): Partitioning = partitioning match {
case HashPartitioning(exprs, numPartitions) =>
HashPartitioning(exprs.map(resolveAttrs(_, plan)), numPartitions)
case KeyGroupedPartitioning(clustering, numPartitions, partValues, originalPartValues) =>
case KeyGroupedPartitioning(clustering, numPartitions, partValues,
originalPartValues, isPartiallyClustered) =>
KeyGroupedPartitioning(clustering.map(resolveAttrs(_, plan)), numPartitions, partValues,
originalPartValues)
originalPartValues, isPartiallyClustered)
case PartitioningCollection(partitionings) =>
PartitioningCollection(partitionings.map(resolvePartitioning(_, plan)))
case RangePartitioning(ordering, numPartitions) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import org.apache.spark.sql.connector.catalog.functions._
import org.apache.spark.sql.connector.distributions.Distributions
import org.apache.spark.sql.connector.expressions._
import org.apache.spark.sql.connector.expressions.Expressions._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan}
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
Expand Down Expand Up @@ -93,13 +93,13 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {

checkQueryPlan(df, catalystDistribution,
physical.KeyGroupedPartitioning(catalystDistribution.clustering, projectedPositions,
partitionValues, partitionValues))
partitionValues, partitionValues, isPartiallyClustered = false))

// multiple group keys should work too as long as partition keys are subset of them
df = sql(s"SELECT count(*) FROM testcat.ns.$table GROUP BY id, ts")
checkQueryPlan(df, catalystDistribution,
physical.KeyGroupedPartitioning(catalystDistribution.clustering, projectedPositions,
partitionValues, partitionValues))
partitionValues, partitionValues, isPartiallyClustered = false))
}

test("non-clustered distribution: no partition") {
Expand Down Expand Up @@ -2747,4 +2747,148 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
Row("ccc", 30, 400.50)))
}
}

test("SPARK-55848: dropDuplicates after SPJ with partial clustering") {
  // Partition both tables on the join key so a storage-partitioned join (SPJ) applies.
  val itemsPartitionSpec = Array(identity("id"))
  createTable(items, itemsColumns, itemsPartitionSpec)
  sql(s"INSERT INTO testcat.ns.$items VALUES " +
    "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
    "(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " +
    "(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
    "(3, 'cc', 15.5, cast('2020-02-01' as timestamp))")

  val purchasesPartitionSpec = Array(identity("item_id"))
  createTable(purchases, purchasesColumns, purchasesPartitionSpec)
  sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
    "(1, 42.0, cast('2020-01-01' as timestamp)), " +
    "(1, 50.0, cast('2020-01-02' as timestamp)), " +
    "(2, 11.0, cast('2020-01-01' as timestamp)), " +
    "(3, 19.5, cast('2020-02-01' as timestamp))")

  withSQLConf(
      SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "true") {
    // A DISTINCT (dropDuplicates) on the join key after a partially-clustered SPJ
    // must still yield one row per distinct id. Before the fix, the
    // partially-clustered partitioning was wrongly reported as satisfying
    // ClusteredDistribution, so EnsureRequirements skipped the Exchange before
    // the dedup and duplicate rows survived.
    val df = sql(
      s"""
         |SELECT DISTINCT i.id
         |FROM testcat.ns.$items i
         |JOIN testcat.ns.$purchases p ON i.id = p.item_id
         |""".stripMargin)
    checkAnswer(df, Seq(Row(1), Row(2), Row(3)))

    val executedPlan = df.queryExecution.executedPlan

    // The fix forces a shuffle between the partially-clustered join and the dedup.
    assert(collectAllShuffles(executedPlan).nonEmpty,
      "should contain a shuffle for the post-join dedup with partial clustering")

    // Sanity check: the flag actually propagated into the scan's output partitioning.
    val hasPartiallyClusteredScan = collectScans(executedPlan).exists {
      _.outputPartitioning match {
        case kgp: physical.KeyGroupedPartitioning => kgp.isPartiallyClustered
        case _ => false
      }
    }
    assert(hasPartiallyClusteredScan,
      "at least one BatchScanExec should have partially-clustered KeyGroupedPartitioning")
  }
}

test("SPARK-55848: Window dedup after SPJ with partial clustering") {
  // Partition both tables on the join key so a storage-partitioned join (SPJ) applies.
  val itemsPartitionSpec = Array(identity("id"))
  createTable(items, itemsColumns, itemsPartitionSpec)
  sql(s"INSERT INTO testcat.ns.$items VALUES " +
    "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
    "(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " +
    "(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
    "(3, 'cc', 15.5, cast('2020-02-01' as timestamp))")

  val purchasesPartitionSpec = Array(identity("item_id"))
  createTable(purchases, purchasesColumns, purchasesPartitionSpec)
  sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
    "(1, 42.0, cast('2020-01-01' as timestamp)), " +
    "(1, 50.0, cast('2020-01-02' as timestamp)), " +
    "(2, 11.0, cast('2020-01-01' as timestamp)), " +
    "(3, 19.5, cast('2020-02-01' as timestamp))")

  withSQLConf(
      SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "true") {
    // Dedup joined rows per id via ROW_NUMBER() OVER after a partially-clustered
    // SPJ. The Window operator requires ClusteredDistribution on i.id; with
    // partial clustering the plan must add a shuffle so the window emits exactly
    // one row per id.
    val df = sql(
      s"""
         |SELECT id, price FROM (
         |  SELECT i.id, i.price,
         |    ROW_NUMBER() OVER (PARTITION BY i.id ORDER BY i.price DESC) AS rn
         |  FROM testcat.ns.$items i
         |  JOIN testcat.ns.$purchases p ON i.id = p.item_id
         |) t WHERE rn = 1
         |""".stripMargin)
    checkAnswer(df, Seq(Row(1, 41.0f), Row(2, 10.0f), Row(3, 15.5f)))

    val executedPlan = df.queryExecution.executedPlan

    // The fix forces a shuffle between the partially-clustered join and the window.
    assert(collectAllShuffles(executedPlan).nonEmpty,
      "should contain a shuffle for the post-join window with partial clustering")

    // Sanity check: the flag actually propagated into the scan's output partitioning.
    val hasPartiallyClusteredScan = collectScans(executedPlan).exists {
      _.outputPartitioning match {
        case kgp: physical.KeyGroupedPartitioning => kgp.isPartiallyClustered
        case _ => false
      }
    }
    assert(hasPartiallyClusteredScan,
      "at least one BatchScanExec should have partially-clustered KeyGroupedPartitioning")
  }
}

test("SPARK-55848: checkpointed partially-clustered join with dedup") {
  withTempDir { dir =>
    spark.sparkContext.setCheckpointDir(dir.getPath)

    // Partition both tables on the join key so a storage-partitioned join (SPJ) applies.
    val itemsPartitionSpec = Array(identity("id"))
    createTable(items, itemsColumns, itemsPartitionSpec)
    sql(s"INSERT INTO testcat.ns.$items VALUES " +
      "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
      "(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " +
      "(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
      "(3, 'cc', 15.5, cast('2020-02-01' as timestamp))")

    val purchasesPartitionSpec = Array(identity("item_id"))
    createTable(purchases, purchasesColumns, purchasesPartitionSpec)
    sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
      "(1, 42.0, cast('2020-01-01' as timestamp)), " +
      "(1, 50.0, cast('2020-01-02' as timestamp)), " +
      "(2, 11.0, cast('2020-01-01' as timestamp)), " +
      "(3, 19.5, cast('2020-02-01' as timestamp))")

    withSQLConf(
        SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
        SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
        SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "true") {
      // Checkpoint the JOIN result (not the scan) so the plan behind the
      // checkpoint (an RDDScanExec) carries partially-clustered
      // KeyGroupedPartitioning. The dedup on top must still insert an Exchange
      // because the isPartiallyClustered flag makes satisfies0() return false
      // for ClusteredDistribution.
      val joinedDf = spark.sql(
        s"""SELECT i.id, i.name, i.price
           |FROM testcat.ns.$items i
           |JOIN testcat.ns.$purchases p ON i.id = p.item_id""".stripMargin)
      val dedupedDf = joinedDf.checkpoint().select("id").distinct()

      checkAnswer(dedupedDf, Seq(Row(1), Row(2), Row(3)))

      val executedPlan = dedupedDf.queryExecution.executedPlan

      // The fix forces a shuffle between the checkpointed join and the dedup.
      assert(collectAllShuffles(executedPlan).nonEmpty,
        "should contain a shuffle for the dedup after checkpointed " +
        "partially-clustered join")

      // Sanity check: the flag survived the checkpoint into the RDD scan node.
      val rddScans = collect(executedPlan) {
        case r: RDDScanExec => r
      }
      val hasPartiallyClusteredScan = rddScans.exists {
        _.outputPartitioning match {
          case kgp: physical.KeyGroupedPartitioning => kgp.isPartiallyClustered
          case _ => false
        }
      }
      assert(hasPartiallyClusteredScan,
        "checkpoint (RDDScanExec) should have " +
        "partially-clustered KeyGroupedPartitioning")
    }
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1127,7 +1127,7 @@ class EnsureRequirementsSuite extends SharedSparkSession {
EnsureRequirements.apply(smjExec) match {
case ShuffledHashJoinExec(_, _, _, _, _,
DummySparkPlan(_, _, left: KeyGroupedPartitioning, _, _),
ShuffleExchangeExec(KeyGroupedPartitioning(attrs, 4, pv, _),
ShuffleExchangeExec(KeyGroupedPartitioning(attrs, 4, pv, _, _),
DummySparkPlan(_, _, SinglePartition, _, _), _, _), _) =>
assert(left.expressions == a1 :: Nil)
assert(attrs == a1 :: Nil)
Expand Down