Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ import org.apache.hadoop.conf.Configuration

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GenericInternalRow, JoinedRow, Literal, Predicate, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.analysis.StreamingJoinHelper
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, GenericInternalRow, JoinedRow, Literal, Predicate, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.types.DataTypeUtils
Expand Down Expand Up @@ -682,6 +683,50 @@ case class StreamingSymmetricHashJoinExec(
private[this] val allowMultipleStatefulOperators: Boolean =
conf.getConf(SQLConf.STATEFUL_OPERATOR_ALLOW_MULTIPLE)

// Precomputed scan-range offsets in microseconds for V4 time-interval joins
// (SPARK-55147). The constant interval offsets are extracted from the join
// condition via StreamingJoinHelper.getStateValueWatermark(eventWatermark = 0);
// the -1 eviction adjustment inside that helper widens the range by ~1ms per
// side, and postJoinFilter still enforces the exact bounds.
private[this] val scanRangeOffsets: Option[(Long, Long)] = {
  // The optimization only applies to V4 state with a value-watermark predicate.
  val usesValueWatermark = stateWatermarkPredicate match {
    case Some(_: JoinStateValueWatermarkPredicate) => true
    case _ => false
  }

  if (stateFormatVersion < 4 || !usesValueWatermark) {
    None
  } else {
    val (thisSideAttrs, otherSideAttrs) = joinSide match {
      case LeftSide => (left.output, right.output)
      case RightSide => (right.output, left.output)
    }

    val lowerBoundMs = StreamingJoinHelper.getStateValueWatermark(
      AttributeSet(otherSideAttrs), AttributeSet(thisSideAttrs), condition.full, Some(0L))
    val upperBoundMs = StreamingJoinHelper.getStateValueWatermark(
      AttributeSet(thisSideAttrs), AttributeSet(otherSideAttrs), condition.full, Some(0L))

    // Both bounds must be derivable from the condition; otherwise fall back to
    // a full scan. Convert ms -> us; the upper bound comes back negated.
    for {
      lower <- lowerBoundMs
      upper <- upperBoundMs
    } yield (lower * 1000L, -upper * 1000L)
  }
}

// Index of the event-time column within inputAttributes, resolved once.
// -1 when the range-scan optimization is off (no scanRangeOffsets) or when
// no event-time column can be found.
private[this] val eventTimeIdxForRangeScan: Int =
  if (scanRangeOffsets.isDefined) {
    WatermarkSupport.findEventTimeColumnIndex(
      inputAttributes, !allowMultipleStatefulOperators).getOrElse(-1)
  } else {
    -1
  }

// Computes the (lower, upper) event-time scan window in microseconds for the
// given input row by applying the precomputed offsets to the row's event time.
// Returns None when the range-scan optimization does not apply.
private def computeTimestampRange(thisRow: UnsafeRow): Option[(Long, Long)] = {
  if (eventTimeIdxForRangeScan < 0) {
    None
  } else {
    scanRangeOffsets.map { case (lowerOffset, upperOffset) =>
      val eventTimeUs = thisRow.getLong(eventTimeIdxForRangeScan)
      (eventTimeUs + lowerOffset, eventTimeUs + upperOffset)
    }
  }
}

/**
* Generate joined rows by consuming input from this side, and matching it with the buffered
* rows (i.e. state) of the other side.
Expand Down Expand Up @@ -758,7 +803,8 @@ case class StreamingSymmetricHashJoinExec(
otherSideJoiner.joinStateManager.getJoinedRows(
key,
thatRow => generateJoinedRow(thisRow, thatRow),
postJoinFilter)
postJoinFilter,
timestampRange = computeTimestampRange(thisRow))
}
val outputIter = generateOutputIter(thisRow, joinedRowIter)
new AddingProcessedRowToStateCompletionIterator(key, thisRow, outputIter)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,15 @@ trait SymmetricHashJoinStateManager {
* required to do so.
*
* It is caller's responsibility to consume the whole iterator.
*
* For V4 time-interval joins, timestampRange may be provided to skip/stop-early during
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's keep the documentation of interface to take a role of "interface". We can just talk about contract/requirement and derived classes will have separate method doc to mention the fact.

I'd rather document the parameter timestampRange as "hint" for optimization of reducing scope of scan. The derived class can make an optimization with that hint but it's still OK for derived class to ignore it, if the derived class cannot leverage that hint.

We should also clarify that, given timestampRange is a hint and a derived class can decide not to leverage it, timestampRange is expected to be a subset of the predicate condition in practice. That means the parameter predicate has to be provided in a way that produces the correct output whether or not timestampRange is leveraged as a hint.

(The reason I said "in practice" is because you leverage this in the test and I admit there is no easy way to test this except breaking the above.)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also let's clarify the boundary of inclusive vs exclusive in both sides. This should be described in the interface method doc.

* prefix scan. Ignored by V1-V3.
*/
def getJoinedRows(
key: UnsafeRow,
generateJoinedRow: InternalRow => JoinedRow,
predicate: JoinedRow => Boolean): Iterator[JoinedRow]
predicate: JoinedRow => Boolean,
timestampRange: Option[(Long, Long)] = None): Iterator[JoinedRow]

/**
* Retrieve all joined rows for the given key and remove the matched rows from state. The joined
Expand Down Expand Up @@ -343,9 +347,8 @@ class SymmetricHashJoinStateManagerV4(
override def getJoinedRows(
key: UnsafeRow,
generateJoinedRow: InternalRow => JoinedRow,
predicate: JoinedRow => Boolean): Iterator[JoinedRow] = {
// TODO: [SPARK-55147] We could improve this method to get the scope of timestamp and scan keys
// more efficiently. For now, we just get all values for the key.
predicate: JoinedRow => Boolean,
timestampRange: Option[(Long, Long)] = None): Iterator[JoinedRow] = {
def getJoinedRowsFromTsAndValues(
ts: Long,
valuesAndMatched: Array[ValueAndMatchPair]): Iterator[JoinedRow] = {
Expand Down Expand Up @@ -399,7 +402,8 @@ class SymmetricHashJoinStateManagerV4(
getJoinedRowsFromTsAndValues(ts, valuesAndMatchedIter.toArray)

case _ =>
keyWithTsToValues.getValues(key).flatMap { result =>
val (minTs, maxTs) = timestampRange.getOrElse((Long.MinValue, Long.MaxValue))
keyWithTsToValues.getValuesInRange(key, minTs, maxTs).flatMap { result =>
val ts = result.timestamp
val valuesAndMatched = result.values.toArray
getJoinedRowsFromTsAndValues(ts, valuesAndMatched)
Expand Down Expand Up @@ -626,67 +630,86 @@ class SymmetricHashJoinStateManagerV4(

// NOTE: This assumes we consume the whole iterator to trigger completion.
def getValues(key: UnsafeRow): Iterator[GetValuesResult] = {
getValuesInRange(key, Long.MinValue, Long.MaxValue)
}

/**
* Returns entries where minTs <= timestamp <= maxTs, grouped by timestamp.
* Filters out entries before minTs and stops iterating past maxTs (timestamps are sorted).
*/
def getValuesInRange(
key: UnsafeRow, minTs: Long, maxTs: Long): Iterator[GetValuesResult] = {
val reusableGetValuesResult = new GetValuesResult()

new NextIterator[GetValuesResult] {
private val iter = stateStore.prefixScanWithMultiValues(key, colFamilyName)

private var currentTs = -1L
private var currentTsInRange = false
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Honestly the new logic is over-complicated. Technically, the entire logic should be the same as before, except for handling the lower and upper bounds. The lower and upper bounds can be handled as exceptional cases, separate from processing the data within the timestamp boundary.

Below is the simplified logic (DISCLAIMER: Claude 4.6 opus) with my guidance of direction for above simplification:

        private var currentTs = -1L
        private var pastUpperBound = false
        private val valueAndMatchPairs = scala.collection.mutable.ArrayBuffer[ValueAndMatchPair]()

        private def flushAccumulated(): GetValuesResult = {
          if (valueAndMatchPairs.nonEmpty) {
            val result = reusableGetValuesResult.withNew(
              currentTs, valueAndMatchPairs.toList)
            currentTs = -1L
            valueAndMatchPairs.clear()
            result
          } else {
            finished = true
            null
          }
        }

        @tailrec
        override protected def getNext(): GetValuesResult = {
          if (pastUpperBound || !iter.hasNext) {
            flushAccumulated()
          } else {
            val unsafeRowPair = iter.next()
            val ts = TimestampKeyStateEncoder.extractTimestamp(unsafeRowPair.key)

            if (ts > maxTs) {
              pastUpperBound = true
              getNext()
            } else if (ts < minTs) {
              getNext()
            } else if (currentTs == -1L || currentTs == ts) {
              currentTs = ts
              valueAndMatchPairs += valueRowConverter.convertValue(unsafeRowPair.value)
              getNext()
            } else {
              // Timestamp changed -- flush previous group before starting new one
              val prevTs = currentTs
              val prevValues = valueAndMatchPairs.toList

              currentTs = ts
              valueAndMatchPairs.clear()
              valueAndMatchPairs += valueRowConverter.convertValue(unsafeRowPair.value)

              reusableGetValuesResult.withNew(prevTs, prevValues)
            }
          }
        }

Would you take a look and apply the change if you think it's good?

private var pastUpperBound = false
private val valueAndMatchPairs = scala.collection.mutable.ArrayBuffer[ValueAndMatchPair]()

private def flushIfInRange(): GetValuesResult = {
if (currentTsInRange && valueAndMatchPairs.nonEmpty) {
val result = reusableGetValuesResult.withNew(
currentTs, valueAndMatchPairs.toList)
currentTs = -1L
currentTsInRange = false
valueAndMatchPairs.clear()
result
} else {
finished = true
null
}
}

@tailrec
override protected def getNext(): GetValuesResult = {
if (iter.hasNext) {
if (pastUpperBound) {
flushIfInRange()
} else if (iter.hasNext) {
val unsafeRowPair = iter.next()

val ts = TimestampKeyStateEncoder.extractTimestamp(unsafeRowPair.key)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this can be considered as "beyond the scope of PR" (since it requires some code change on StateStore API), but just for recording:

We might get better outcome if we can specify the start boundary and the end boundary when scanning in general. This would be useful for range scan and prefix scan with timestamp (timestamp as postfix).

Seeking to specific position directly is better than seeking to first position and doing sequential scan to find the specific position. (RocksDB won't find the position via sequentially scan)

Scoping the iterator's upper bound tightly would also help RocksDB avoid unnecessary scanning (especially of tombstones), although it's probably not very different given the pattern in which we remove state (we remove state in timestamp order, so it's unlikely to have tombstones beyond the upper bound if we ever have a valid entry within the timestamp boundary).


if (currentTs == -1L) {
// First time
if (ts > maxTs) {
pastUpperBound = true
getNext()
} else if (currentTs == -1L) {
currentTs = ts
}

if (currentTs != ts) {
assert(valueAndMatchPairs.nonEmpty,
"timestamp has changed but no values collected from previous timestamp! " +
s"This should not happen. currentTs: $currentTs, new ts: $ts")

// Return previous batch
val result = reusableGetValuesResult.withNew(
currentTs, valueAndMatchPairs.toSeq)
currentTsInRange = ts >= minTs
if (currentTsInRange) {
valueAndMatchPairs += valueRowConverter.convertValue(unsafeRowPair.value)
}
getNext()
} else if (currentTs != ts) {
// Timestamp changed -- capture previous batch before resetting
val prevTs = currentTs
val prevValues = if (currentTsInRange && valueAndMatchPairs.nonEmpty) {
valueAndMatchPairs.toList
} else {
null
}

// Reset for new timestamp
currentTs = ts
currentTsInRange = ts >= minTs
valueAndMatchPairs.clear()
if (currentTsInRange) {
valueAndMatchPairs += valueRowConverter.convertValue(unsafeRowPair.value)
}

// Add current value
val value = valueRowConverter.convertValue(unsafeRowPair.value)
valueAndMatchPairs += value
result
if (prevValues != null) {
reusableGetValuesResult.withNew(prevTs, prevValues)
} else {
getNext()
}
} else {
// Same timestamp, accumulate values
val value = valueRowConverter.convertValue(unsafeRowPair.value)
valueAndMatchPairs += value

// Continue to next
if (currentTsInRange) {
valueAndMatchPairs += valueRowConverter.convertValue(unsafeRowPair.value)
}
getNext()
}
} else {
if (currentTs != -1L) {
assert(valueAndMatchPairs.nonEmpty)

// Return last batch
val result = reusableGetValuesResult.withNew(
currentTs, valueAndMatchPairs.toSeq)

// Mark as finished
currentTs = -1L
valueAndMatchPairs.clear()
result
} else {
finished = true
null
}
flushIfInRange()
}
}

Expand Down Expand Up @@ -1051,7 +1074,8 @@ abstract class SymmetricHashJoinStateManagerBase(
def getJoinedRows(
key: UnsafeRow,
generateJoinedRow: InternalRow => JoinedRow,
predicate: JoinedRow => Boolean): Iterator[JoinedRow] = {
predicate: JoinedRow => Boolean,
timestampRange: Option[(Long, Long)] = None): Iterator[JoinedRow] = {
val numValues = keyToNumValues.get(key)
keyWithIndexToValue.getAll(key, numValues).map { keyIdxToValue =>
val joinedRow = generateJoinedRow(keyIdxToValue.value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1009,4 +1009,48 @@ class SymmetricHashJoinStateManagerEventTimeInValueSuite
}
}
}

private def getJoinedRowTimestamps(
key: Int,
range: Option[(Long, Long)])(implicit manager: SymmetricHashJoinStateManager): Seq[Int] = {
val dummyRow = new GenericInternalRow(0)
manager.getJoinedRows(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's briefly leave code comment that predicate should contain the condition of timestampRange in practice, but we don't do it intentionally for testing the functionality.

toJoinKeyRow(key),
row => new JoinedRow(row, dummyRow),
_ => true,
timestampRange = range
).map(_.getInt(1)).toSeq.sorted
}

test("StreamingJoinStateManager V4 - getJoinedRows with timestampRange") {
  withJoinStateManager(
    inputValueAttributes, joinKeyExpressions, stateFormatVersion = 4) { manager =>
    implicit val mgr = manager

    // One value per timestamp bucket under key 40.
    for (ts <- Seq(10, 20, 30, 40, 50)) {
      append(40, ts)
    }

    // Bounds are inclusive on both ends.
    assert(getJoinedRowTimestamps(40, Some((20L, 40L))) === Seq(20, 30, 40))
    // Degenerate single-point range.
    assert(getJoinedRowTimestamps(40, Some((20L, 20L))) === Seq(20))
    // Bounds falling between buckets still pick up interior values.
    assert(getJoinedRowTimestamps(40, Some((25L, 35L))) === Seq(30))
    // Range covering everything returns all values.
    assert(getJoinedRowTimestamps(40, Some((0L, 100L))) === Seq(10, 20, 30, 40, 50))
    // Ranges touching the low and high ends of the data.
    assert(getJoinedRowTimestamps(40, Some((10L, 30L))) === Seq(10, 20, 30))
    assert(getJoinedRowTimestamps(40, Some((50L, 100L))) === Seq(50))
    // Ranges entirely outside the data yield nothing.
    assert(getJoinedRowTimestamps(40, Some((60L, 100L))) === Seq.empty)
    assert(getJoinedRowTimestamps(40, Some((0L, 5L))) === Seq.empty)
    // No range means a full scan.
    assert(getJoinedRowTimestamps(40, None) === Seq(10, 20, 30, 40, 50))
  }
}

test("StreamingJoinStateManager V4 - timestampRange with multiple values per timestamp") {
withJoinStateManager(
inputValueAttributes, joinKeyExpressions, stateFormatVersion = 4) { manager =>
implicit val mgr = manager

append(40, 20)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not adding a bit more to show that the method works correctly with multiple values across multiple timestamp buckets?

append(40, 20) // same timestamp bucket
append(40, 30)

assert(getJoinedRowTimestamps(40, Some((20L, 20L))) === Seq(20, 20))
}
}
}