Skip to content

Commit 0ba6e70

Browse files
authored
Merge SortMergeJoin filtered batches into larger batches (#14160)
* Merge SortMergeJoin filtered batches into bigger batches
1 parent 274e535 commit 0ba6e70

File tree

1 file changed

+95
-40
lines changed

1 file changed

+95
-40
lines changed

datafusion/physical-plan/src/joins/sort_merge_join.rs

Lines changed: 95 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -792,7 +792,10 @@ struct SortMergeJoinStream {
792792
/// optional join filter
793793
pub filter: Option<JoinFilter>,
794794
/// Staging output array builders
795-
pub output_record_batches: JoinedRecordBatches,
795+
pub staging_output_record_batches: JoinedRecordBatches,
796+
/// Output buffer. Currently used by filtering as it requires double buffering
797+
/// to avoid small/empty batches. Non-filtered join outputs directly from `staging_output_record_batches.batches`
798+
pub output: RecordBatch,
796799
/// Staging output size, including output batches and staging joined results.
797800
/// Increased when we put rows into buffer and decreased after we actually output batches.
798801
/// Used to trigger output when sufficient rows are ready
@@ -1053,13 +1056,35 @@ impl Stream for SortMergeJoinStream {
10531056
{
10541057
self.freeze_all()?;
10551058

1056-
if !self.output_record_batches.batches.is_empty()
1059+
// If join is filtered and there is joined tuples waiting
1060+
// to be filtered
1061+
if !self
1062+
.staging_output_record_batches
1063+
.batches
1064+
.is_empty()
10571065
{
1066+
// Apply filter on joined tuples and get filtered batch
10581067
let out_filtered_batch =
10591068
self.filter_joined_batch()?;
1060-
return Poll::Ready(Some(Ok(
1061-
out_filtered_batch,
1062-
)));
1069+
1070+
// Append filtered batch to the output buffer
1071+
self.output = concat_batches(
1072+
&self.schema(),
1073+
vec![&self.output, &out_filtered_batch],
1074+
)?;
1075+
1076+
// Send to output if the output buffer surpassed the `batch_size`
1077+
if self.output.num_rows() >= self.batch_size {
1078+
let record_batch = std::mem::replace(
1079+
&mut self.output,
1080+
RecordBatch::new_empty(
1081+
out_filtered_batch.schema(),
1082+
),
1083+
);
1084+
return Poll::Ready(Some(Ok(
1085+
record_batch,
1086+
)));
1087+
}
10631088
}
10641089
}
10651090

@@ -1116,7 +1141,7 @@ impl Stream for SortMergeJoinStream {
11161141
}
11171142
} else {
11181143
self.freeze_all()?;
1119-
if !self.output_record_batches.batches.is_empty() {
1144+
if !self.staging_output_record_batches.batches.is_empty() {
11201145
let record_batch = self.output_record_batch_and_reset()?;
11211146
// For non-filtered join output whenever the target output batch size
11221147
// is hit. For filtered join its needed to output on later phase
@@ -1146,7 +1171,8 @@ impl Stream for SortMergeJoinStream {
11461171
SortMergeJoinState::Exhausted => {
11471172
self.freeze_all()?;
11481173

1149-
if !self.output_record_batches.batches.is_empty() {
1174+
// if there is still something not processed
1175+
if !self.staging_output_record_batches.batches.is_empty() {
11501176
if self.filter.is_some()
11511177
&& matches!(
11521178
self.join_type,
@@ -1159,12 +1185,20 @@ impl Stream for SortMergeJoinStream {
11591185
| JoinType::LeftMark
11601186
)
11611187
{
1162-
let out = self.filter_joined_batch()?;
1163-
return Poll::Ready(Some(Ok(out)));
1188+
let record_batch = self.filter_joined_batch()?;
1189+
return Poll::Ready(Some(Ok(record_batch)));
11641190
} else {
11651191
let record_batch = self.output_record_batch_and_reset()?;
11661192
return Poll::Ready(Some(Ok(record_batch)));
11671193
}
1194+
} else if self.output.num_rows() > 0 {
1195+
// if processed but still not outputted because it didn't hit batch size before
1196+
let schema = self.output.schema();
1197+
let record_batch = std::mem::replace(
1198+
&mut self.output,
1199+
RecordBatch::new_empty(schema),
1200+
);
1201+
return Poll::Ready(Some(Ok(record_batch)));
11681202
} else {
11691203
return Poll::Ready(None);
11701204
}
@@ -1197,7 +1231,7 @@ impl SortMergeJoinStream {
11971231
state: SortMergeJoinState::Init,
11981232
sort_options,
11991233
null_equals_null,
1200-
schema,
1234+
schema: Arc::clone(&schema),
12011235
streamed_schema: Arc::clone(&streamed_schema),
12021236
buffered_schema,
12031237
streamed,
@@ -1212,12 +1246,13 @@ impl SortMergeJoinStream {
12121246
on_streamed,
12131247
on_buffered,
12141248
filter,
1215-
output_record_batches: JoinedRecordBatches {
1249+
staging_output_record_batches: JoinedRecordBatches {
12161250
batches: vec![],
12171251
filter_mask: BooleanBuilder::new(),
12181252
row_indices: UInt64Builder::new(),
12191253
batch_ids: vec![],
12201254
},
1255+
output: RecordBatch::new_empty(schema),
12211256
output_size: 0,
12221257
batch_size,
12231258
join_type,
@@ -1607,17 +1642,20 @@ impl SortMergeJoinStream {
16071642
buffered_batch,
16081643
)? {
16091644
let num_rows = record_batch.num_rows();
1610-
self.output_record_batches
1645+
self.staging_output_record_batches
16111646
.filter_mask
16121647
.append_nulls(num_rows);
1613-
self.output_record_batches
1648+
self.staging_output_record_batches
16141649
.row_indices
16151650
.append_nulls(num_rows);
1616-
self.output_record_batches
1617-
.batch_ids
1618-
.resize(self.output_record_batches.batch_ids.len() + num_rows, 0);
1651+
self.staging_output_record_batches.batch_ids.resize(
1652+
self.staging_output_record_batches.batch_ids.len() + num_rows,
1653+
0,
1654+
);
16191655

1620-
self.output_record_batches.batches.push(record_batch);
1656+
self.staging_output_record_batches
1657+
.batches
1658+
.push(record_batch);
16211659
}
16221660
buffered_batch.null_joined.clear();
16231661
}
@@ -1651,16 +1689,19 @@ impl SortMergeJoinStream {
16511689
)? {
16521690
let num_rows = record_batch.num_rows();
16531691

1654-
self.output_record_batches
1692+
self.staging_output_record_batches
16551693
.filter_mask
16561694
.append_nulls(num_rows);
1657-
self.output_record_batches
1695+
self.staging_output_record_batches
16581696
.row_indices
16591697
.append_nulls(num_rows);
1660-
self.output_record_batches
1661-
.batch_ids
1662-
.resize(self.output_record_batches.batch_ids.len() + num_rows, 0);
1663-
self.output_record_batches.batches.push(record_batch);
1698+
self.staging_output_record_batches.batch_ids.resize(
1699+
self.staging_output_record_batches.batch_ids.len() + num_rows,
1700+
0,
1701+
);
1702+
self.staging_output_record_batches
1703+
.batches
1704+
.push(record_batch);
16641705
}
16651706
buffered_batch.join_filter_not_matched_map.clear();
16661707

@@ -1792,20 +1833,29 @@ impl SortMergeJoinStream {
17921833
| JoinType::LeftMark
17931834
| JoinType::Full
17941835
) {
1795-
self.output_record_batches.batches.push(output_batch);
1836+
self.staging_output_record_batches
1837+
.batches
1838+
.push(output_batch);
17961839
} else {
17971840
let filtered_batch = filter_record_batch(&output_batch, &mask)?;
1798-
self.output_record_batches.batches.push(filtered_batch);
1841+
self.staging_output_record_batches
1842+
.batches
1843+
.push(filtered_batch);
17991844
}
18001845

18011846
if !matches!(self.join_type, JoinType::Full) {
1802-
self.output_record_batches.filter_mask.extend(&mask);
1847+
self.staging_output_record_batches.filter_mask.extend(&mask);
18031848
} else {
1804-
self.output_record_batches.filter_mask.extend(pre_mask);
1849+
self.staging_output_record_batches
1850+
.filter_mask
1851+
.extend(pre_mask);
18051852
}
1806-
self.output_record_batches.row_indices.extend(&left_indices);
1807-
self.output_record_batches.batch_ids.resize(
1808-
self.output_record_batches.batch_ids.len() + left_indices.len(),
1853+
self.staging_output_record_batches
1854+
.row_indices
1855+
.extend(&left_indices);
1856+
self.staging_output_record_batches.batch_ids.resize(
1857+
self.staging_output_record_batches.batch_ids.len()
1858+
+ left_indices.len(),
18091859
self.streamed_batch_counter.load(Relaxed),
18101860
);
18111861

@@ -1837,10 +1887,14 @@ impl SortMergeJoinStream {
18371887
}
18381888
}
18391889
} else {
1840-
self.output_record_batches.batches.push(output_batch);
1890+
self.staging_output_record_batches
1891+
.batches
1892+
.push(output_batch);
18411893
}
18421894
} else {
1843-
self.output_record_batches.batches.push(output_batch);
1895+
self.staging_output_record_batches
1896+
.batches
1897+
.push(output_batch);
18441898
}
18451899
}
18461900

@@ -1851,7 +1905,7 @@ impl SortMergeJoinStream {
18511905

18521906
fn output_record_batch_and_reset(&mut self) -> Result<RecordBatch> {
18531907
let record_batch =
1854-
concat_batches(&self.schema, &self.output_record_batches.batches)?;
1908+
concat_batches(&self.schema, &self.staging_output_record_batches.batches)?;
18551909
self.join_metrics.output_batches.add(1);
18561910
self.join_metrics.output_rows.add(record_batch.num_rows());
18571911
// If join filter exists, `self.output_size` is not accurate as we don't know the exact
@@ -1877,16 +1931,17 @@ impl SortMergeJoinStream {
18771931
| JoinType::Full
18781932
))
18791933
{
1880-
self.output_record_batches.batches.clear();
1934+
self.staging_output_record_batches.batches.clear();
18811935
}
18821936
Ok(record_batch)
18831937
}
18841938

18851939
fn filter_joined_batch(&mut self) -> Result<RecordBatch> {
1886-
let record_batch = self.output_record_batch_and_reset()?;
1887-
let mut out_indices = self.output_record_batches.row_indices.finish();
1888-
let mut out_mask = self.output_record_batches.filter_mask.finish();
1889-
let mut batch_ids = &self.output_record_batches.batch_ids;
1940+
let record_batch =
1941+
concat_batches(&self.schema, &self.staging_output_record_batches.batches)?;
1942+
let mut out_indices = self.staging_output_record_batches.row_indices.finish();
1943+
let mut out_mask = self.staging_output_record_batches.filter_mask.finish();
1944+
let mut batch_ids = &self.staging_output_record_batches.batch_ids;
18901945
let default_batch_ids = vec![0; record_batch.num_rows()];
18911946

18921947
// If only nulls come in and indices sizes doesn't match with expected record batch count
@@ -1901,7 +1956,7 @@ impl SortMergeJoinStream {
19011956
}
19021957

19031958
if out_mask.is_empty() {
1904-
self.output_record_batches.batches.clear();
1959+
self.staging_output_record_batches.batches.clear();
19051960
return Ok(record_batch);
19061961
}
19071962

@@ -2044,7 +2099,7 @@ impl SortMergeJoinStream {
20442099
)?;
20452100
}
20462101

2047-
self.output_record_batches.clear();
2102+
self.staging_output_record_batches.clear();
20482103

20492104
Ok(filtered_record_batch)
20502105
}

0 commit comments

Comments
 (0)