-
Notifications
You must be signed in to change notification settings - Fork 29.1k
[SPARK-55147][SS] Scope timestamp range for time-interval join retrieval in V4 state format #54879
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -67,11 +67,15 @@ trait SymmetricHashJoinStateManager { | |
| * required to do so. | ||
| * | ||
| * It is caller's responsibility to consume the whole iterator. | ||
| * | ||
| * For V4 time-interval joins, timestampRange may be provided to skip/stop-early during | ||
| * prefix scan. Ignored by V1-V3. | ||
| */ | ||
| def getJoinedRows( | ||
| key: UnsafeRow, | ||
| generateJoinedRow: InternalRow => JoinedRow, | ||
| predicate: JoinedRow => Boolean): Iterator[JoinedRow] | ||
| predicate: JoinedRow => Boolean, | ||
| timestampRange: Option[(Long, Long)] = None): Iterator[JoinedRow] | ||
|
|
||
| /** | ||
| * Retrieve all joined rows for the given key and remove the matched rows from state. The joined | ||
|
|
@@ -343,9 +347,8 @@ class SymmetricHashJoinStateManagerV4( | |
| override def getJoinedRows( | ||
| key: UnsafeRow, | ||
| generateJoinedRow: InternalRow => JoinedRow, | ||
| predicate: JoinedRow => Boolean): Iterator[JoinedRow] = { | ||
| // TODO: [SPARK-55147] We could improve this method to get the scope of timestamp and scan keys | ||
| // more efficiently. For now, we just get all values for the key. | ||
| predicate: JoinedRow => Boolean, | ||
| timestampRange: Option[(Long, Long)] = None): Iterator[JoinedRow] = { | ||
| def getJoinedRowsFromTsAndValues( | ||
| ts: Long, | ||
| valuesAndMatched: Array[ValueAndMatchPair]): Iterator[JoinedRow] = { | ||
|
|
@@ -399,7 +402,8 @@ class SymmetricHashJoinStateManagerV4( | |
| getJoinedRowsFromTsAndValues(ts, valuesAndMatchedIter.toArray) | ||
|
|
||
| case _ => | ||
| keyWithTsToValues.getValues(key).flatMap { result => | ||
| val (minTs, maxTs) = timestampRange.getOrElse((Long.MinValue, Long.MaxValue)) | ||
| keyWithTsToValues.getValuesInRange(key, minTs, maxTs).flatMap { result => | ||
| val ts = result.timestamp | ||
| val valuesAndMatched = result.values.toArray | ||
| getJoinedRowsFromTsAndValues(ts, valuesAndMatched) | ||
|
|
@@ -626,67 +630,86 @@ class SymmetricHashJoinStateManagerV4( | |
|
|
||
| // NOTE: This assumes we consume the whole iterator to trigger completion. | ||
| def getValues(key: UnsafeRow): Iterator[GetValuesResult] = { | ||
| getValuesInRange(key, Long.MinValue, Long.MaxValue) | ||
| } | ||
|
|
||
| /** | ||
| * Returns entries where minTs <= timestamp <= maxTs, grouped by timestamp. | ||
| * Filters out entries before minTs and stops iterating past maxTs (timestamps are sorted). | ||
| */ | ||
| def getValuesInRange( | ||
| key: UnsafeRow, minTs: Long, maxTs: Long): Iterator[GetValuesResult] = { | ||
| val reusableGetValuesResult = new GetValuesResult() | ||
|
|
||
| new NextIterator[GetValuesResult] { | ||
| private val iter = stateStore.prefixScanWithMultiValues(key, colFamilyName) | ||
|
|
||
| private var currentTs = -1L | ||
| private var currentTsInRange = false | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Honestly the new logic is over complicated. Technically, the entire logic should be the same as before, except handling lower bound and upper bound. Handling lower bound and upper bound can be handled as exceptional case than processing the data in timestamp boundary. Below is the simplified logic (DISCLAIMER: Claude 4.6 opus) with my guidance of direction for above simplification: Would you take a look and apply the change if you think it's good? |
||
| private var pastUpperBound = false | ||
| private val valueAndMatchPairs = scala.collection.mutable.ArrayBuffer[ValueAndMatchPair]() | ||
|
|
||
| private def flushIfInRange(): GetValuesResult = { | ||
| if (currentTsInRange && valueAndMatchPairs.nonEmpty) { | ||
| val result = reusableGetValuesResult.withNew( | ||
| currentTs, valueAndMatchPairs.toList) | ||
| currentTs = -1L | ||
| currentTsInRange = false | ||
| valueAndMatchPairs.clear() | ||
| result | ||
| } else { | ||
| finished = true | ||
| null | ||
| } | ||
| } | ||
|
|
||
| @tailrec | ||
| override protected def getNext(): GetValuesResult = { | ||
| if (iter.hasNext) { | ||
| if (pastUpperBound) { | ||
| flushIfInRange() | ||
| } else if (iter.hasNext) { | ||
| val unsafeRowPair = iter.next() | ||
|
|
||
| val ts = TimestampKeyStateEncoder.extractTimestamp(unsafeRowPair.key) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this can be considered as "beyond the scope of PR" (since it requires some code change on StateStore API), but just for recording: We might get better outcome if we can specify the start boundary and the end boundary when scanning in general. This would be useful for range scan and prefix scan with timestamp (timestamp as postfix). Seeking to specific position directly is better than seeking to first position and doing sequential scan to find the specific position. (RocksDB won't find the position via sequentially scan) Scoping the iterator to closely in upper bound would also help RocksDB to avoid unnecessary scanning (especially tombstones), although it's not probably very different due to the pattern we remove the state (we remove the state in timestamp order, so it's unlikely to have tombstones beyond upper bound if we ever have a valid entry within timestamp boundary. |
||
|
|
||
| if (currentTs == -1L) { | ||
| // First time | ||
| if (ts > maxTs) { | ||
| pastUpperBound = true | ||
| getNext() | ||
| } else if (currentTs == -1L) { | ||
| currentTs = ts | ||
| } | ||
|
|
||
| if (currentTs != ts) { | ||
| assert(valueAndMatchPairs.nonEmpty, | ||
| "timestamp has changed but no values collected from previous timestamp! " + | ||
| s"This should not happen. currentTs: $currentTs, new ts: $ts") | ||
|
|
||
| // Return previous batch | ||
| val result = reusableGetValuesResult.withNew( | ||
| currentTs, valueAndMatchPairs.toSeq) | ||
| currentTsInRange = ts >= minTs | ||
| if (currentTsInRange) { | ||
| valueAndMatchPairs += valueRowConverter.convertValue(unsafeRowPair.value) | ||
| } | ||
| getNext() | ||
| } else if (currentTs != ts) { | ||
| // Timestamp changed -- capture previous batch before resetting | ||
| val prevTs = currentTs | ||
| val prevValues = if (currentTsInRange && valueAndMatchPairs.nonEmpty) { | ||
| valueAndMatchPairs.toList | ||
| } else { | ||
| null | ||
| } | ||
|
|
||
| // Reset for new timestamp | ||
| currentTs = ts | ||
| currentTsInRange = ts >= minTs | ||
| valueAndMatchPairs.clear() | ||
| if (currentTsInRange) { | ||
| valueAndMatchPairs += valueRowConverter.convertValue(unsafeRowPair.value) | ||
| } | ||
|
|
||
| // Add current value | ||
| val value = valueRowConverter.convertValue(unsafeRowPair.value) | ||
| valueAndMatchPairs += value | ||
| result | ||
| if (prevValues != null) { | ||
| reusableGetValuesResult.withNew(prevTs, prevValues) | ||
| } else { | ||
| getNext() | ||
| } | ||
| } else { | ||
| // Same timestamp, accumulate values | ||
| val value = valueRowConverter.convertValue(unsafeRowPair.value) | ||
| valueAndMatchPairs += value | ||
|
|
||
| // Continue to next | ||
| if (currentTsInRange) { | ||
| valueAndMatchPairs += valueRowConverter.convertValue(unsafeRowPair.value) | ||
| } | ||
| getNext() | ||
| } | ||
| } else { | ||
| if (currentTs != -1L) { | ||
| assert(valueAndMatchPairs.nonEmpty) | ||
|
|
||
| // Return last batch | ||
| val result = reusableGetValuesResult.withNew( | ||
| currentTs, valueAndMatchPairs.toSeq) | ||
|
|
||
| // Mark as finished | ||
| currentTs = -1L | ||
| valueAndMatchPairs.clear() | ||
| result | ||
| } else { | ||
| finished = true | ||
| null | ||
| } | ||
| flushIfInRange() | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -1051,7 +1074,8 @@ abstract class SymmetricHashJoinStateManagerBase( | |
| def getJoinedRows( | ||
| key: UnsafeRow, | ||
| generateJoinedRow: InternalRow => JoinedRow, | ||
| predicate: JoinedRow => Boolean): Iterator[JoinedRow] = { | ||
| predicate: JoinedRow => Boolean, | ||
| timestampRange: Option[(Long, Long)] = None): Iterator[JoinedRow] = { | ||
| val numValues = keyToNumValues.get(key) | ||
| keyWithIndexToValue.getAll(key, numValues).map { keyIdxToValue => | ||
| val joinedRow = generateJoinedRow(keyIdxToValue.value) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1009,4 +1009,48 @@ class SymmetricHashJoinStateManagerEventTimeInValueSuite | |
| } | ||
| } | ||
| } | ||
|
|
||
| private def getJoinedRowTimestamps( | ||
| key: Int, | ||
| range: Option[(Long, Long)])(implicit manager: SymmetricHashJoinStateManager): Seq[Int] = { | ||
| val dummyRow = new GenericInternalRow(0) | ||
| manager.getJoinedRows( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's briefly leave code comment that |
||
| toJoinKeyRow(key), | ||
| row => new JoinedRow(row, dummyRow), | ||
| _ => true, | ||
| timestampRange = range | ||
| ).map(_.getInt(1)).toSeq.sorted | ||
| } | ||
|
|
||
| test("StreamingJoinStateManager V4 - getJoinedRows with timestampRange") { | ||
| withJoinStateManager( | ||
| inputValueAttributes, joinKeyExpressions, stateFormatVersion = 4) { manager => | ||
| implicit val mgr = manager | ||
|
|
||
| Seq(10, 20, 30, 40, 50).foreach(append(40, _)) | ||
|
|
||
| assert(getJoinedRowTimestamps(40, Some((20L, 40L))) === Seq(20, 30, 40)) | ||
| assert(getJoinedRowTimestamps(40, Some((20L, 20L))) === Seq(20)) | ||
| assert(getJoinedRowTimestamps(40, Some((25L, 35L))) === Seq(30)) | ||
| assert(getJoinedRowTimestamps(40, Some((0L, 100L))) === Seq(10, 20, 30, 40, 50)) | ||
| assert(getJoinedRowTimestamps(40, Some((10L, 30L))) === Seq(10, 20, 30)) | ||
| assert(getJoinedRowTimestamps(40, Some((50L, 100L))) === Seq(50)) | ||
| assert(getJoinedRowTimestamps(40, Some((60L, 100L))) === Seq.empty) | ||
| assert(getJoinedRowTimestamps(40, Some((0L, 5L))) === Seq.empty) | ||
| assert(getJoinedRowTimestamps(40, None) === Seq(10, 20, 30, 40, 50)) | ||
| } | ||
| } | ||
|
|
||
| test("StreamingJoinStateManager V4 - timestampRange with multiple values per timestamp") { | ||
| withJoinStateManager( | ||
| inputValueAttributes, joinKeyExpressions, stateFormatVersion = 4) { manager => | ||
| implicit val mgr = manager | ||
|
|
||
| append(40, 20) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why not adding a bit more to show that the method works correctly with multiple values across multiple timestamp buckets? |
||
| append(40, 20) // same timestamp bucket | ||
| append(40, 30) | ||
|
|
||
| assert(getJoinedRowTimestamps(40, Some((20L, 20L))) === Seq(20, 20)) | ||
| } | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's keep the documentation of interface to take a role of "interface". We can just talk about contract/requirement and derived classes will have separate method doc to mention the fact.
I'd rather document the parameter
timestampRangeas "hint" for optimization of reducing scope of scan. The derived class can make an optimization with that hint but it's still OK for derived class to ignore it, if the derived class cannot leverage that hint.We should still also clarify that given
timestampRangeis a hint and derived class can decide not to leverage it,timestampRangeis expected to be a subset ofpredicatecondition in practice. That means, the parameterpredicatehas to be provided in a way to produce the correct output whether the parametertimestampRangeis leveraged as a hint or not.(The reason I said "in practice" is because you leverage this in the test and I admit there is no easy way to test this except breaking the above.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also let's clarify the boundary of inclusive vs exclusive in both sides. This should be described in the interface method doc.