Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 13 additions & 16 deletions datafusion/physical-plan/src/joins/sort_merge_join/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ pub(super) struct StreamedBatch {
pub join_arrays: Vec<ArrayRef>,
/// Chunks of indices from buffered side (may be nulls) joined to streamed
pub output_indices: Vec<StreamedJoinedChunk>,
/// Total number of output rows across all chunks in `output_indices`
pub num_output_rows: usize,
/// Index of currently scanned batch from buffered data
pub buffered_batch_idx: Option<usize>,
/// Indices that found a match for the given join filter
Expand All @@ -142,6 +144,7 @@ impl StreamedBatch {
idx: 0,
join_arrays,
output_indices: vec![],
num_output_rows: 0,
buffered_batch_idx: None,
join_filter_matched_idxs: HashSet::new(),
}
Expand All @@ -153,17 +156,15 @@ impl StreamedBatch {
idx: 0,
join_arrays: vec![],
output_indices: vec![],
num_output_rows: 0,
buffered_batch_idx: None,
join_filter_matched_idxs: HashSet::new(),
}
}

/// Number of unfrozen output pairs in this streamed batch
fn num_output_rows(&self) -> usize {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I think this function can be removed — callers can read the count directly as streamed_batch.num_output_rows — and we could add a small comment on the num_output_rows field stating that it represents the number of unfrozen pairs.

self.output_indices
.iter()
.map(|chunk| chunk.streamed_indices.len())
.sum()
self.num_output_rows
}

/// Appends new pair consisting of current streamed index and `buffered_idx`
Expand All @@ -173,20 +174,20 @@ impl StreamedBatch {
buffered_batch_idx: Option<usize>,
buffered_idx: Option<usize>,
batch_size: usize,
num_unfrozen_pairs: usize,
) {
// If no current chunk exists or current chunk is not for current buffered batch,
// create a new chunk
if self.output_indices.is_empty() || self.buffered_batch_idx != buffered_batch_idx
{
// Compute capacity only when creating a new chunk (infrequent operation).
// The capacity is the remaining space to reach batch_size.
// This should always be >= 1 since we only call this when num_unfrozen_pairs < batch_size.
// This should always be >= 1 since we only call this when num_output_rows < batch_size.
debug_assert!(
batch_size > num_unfrozen_pairs,
"batch_size ({batch_size}) must be > num_unfrozen_pairs ({num_unfrozen_pairs})"
batch_size > self.num_output_rows,
"batch_size ({batch_size}) must be > num_output_rows ({})",
self.num_output_rows
);
let capacity = batch_size - num_unfrozen_pairs;
let capacity = batch_size - self.num_output_rows;
self.output_indices.push(StreamedJoinedChunk {
buffered_batch_idx,
streamed_indices: UInt64Builder::with_capacity(capacity),
Expand All @@ -203,6 +204,7 @@ impl StreamedBatch {
} else {
current_chunk.buffered_indices.append_null();
}
self.num_output_rows += 1;
}
}

Expand Down Expand Up @@ -1100,13 +1102,10 @@ impl SortMergeJoinStream {
let scanning_idx = self.buffered_data.scanning_idx();
if join_streamed {
// Join streamed row and buffered row
// Pass batch_size and num_unfrozen_pairs to compute capacity only when
// creating a new chunk (when buffered_batch_idx changes), not on every iteration.
self.streamed_batch.append_output_pair(
Some(self.buffered_data.scanning_batch_idx),
Some(scanning_idx),
self.batch_size,
self.num_unfrozen_pairs(),
);
} else {
// Join nulls and buffered row for FULL join
Expand All @@ -1132,13 +1131,10 @@ impl SortMergeJoinStream {
// For Mark join we store a dummy id to indicate the row has a match
let scanning_idx = mark_row_as_match.then_some(0);

// Pass batch_size=1 and num_unfrozen_pairs=0 to get capacity of 1,
// since we only append a single null-joined pair here (not in a loop).
self.streamed_batch.append_output_pair(
scanning_batch_idx,
scanning_idx,
1,
0,
self.batch_size,
);
self.buffered_data.scanning_finish();
self.streamed_joined = true;
Expand Down Expand Up @@ -1437,6 +1433,7 @@ impl SortMergeJoinStream {
}

self.streamed_batch.output_indices.clear();
self.streamed_batch.num_output_rows = 0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should encapsulate this reset pattern into its own method: e.g. a self.reset() that both calls output_indices.clear() and sets num_output_rows = 0, so the two can never get out of sync.


Ok(())
}
Expand Down