@@ -792,7 +792,10 @@ struct SortMergeJoinStream {
792
792
/// optional join filter
793
793
pub filter : Option < JoinFilter > ,
794
794
/// Staging output array builders
795
- pub output_record_batches : JoinedRecordBatches ,
795
+ pub staging_output_record_batches : JoinedRecordBatches ,
796
+ /// Output buffer. Currently used by filtering as it requires double buffering
797
+ /// to avoid small/empty batches. Non-filtered join outputs directly from `staging_output_record_batches.batches`
798
+ pub output : RecordBatch ,
796
799
/// Staging output size, including output batches and staging joined results.
797
800
/// Increased when we put rows into buffer and decreased after we actually output batches.
798
801
/// Used to trigger output when sufficient rows are ready
@@ -1053,13 +1056,35 @@ impl Stream for SortMergeJoinStream {
1053
1056
{
1054
1057
self . freeze_all ( ) ?;
1055
1058
1056
- if !self . output_record_batches . batches . is_empty ( )
1059
+ // If join is filtered and there is joined tuples waiting
1060
+ // to be filtered
1061
+ if !self
1062
+ . staging_output_record_batches
1063
+ . batches
1064
+ . is_empty ( )
1057
1065
{
1066
+ // Apply filter on joined tuples and get filtered batch
1058
1067
let out_filtered_batch =
1059
1068
self . filter_joined_batch ( ) ?;
1060
- return Poll :: Ready ( Some ( Ok (
1061
- out_filtered_batch,
1062
- ) ) ) ;
1069
+
1070
+ // Append filtered batch to the output buffer
1071
+ self . output = concat_batches (
1072
+ & self . schema ( ) ,
1073
+ vec ! [ & self . output, & out_filtered_batch] ,
1074
+ ) ?;
1075
+
1076
+ // Send to output if the output buffer surpassed the `batch_size`
1077
+ if self . output . num_rows ( ) >= self . batch_size {
1078
+ let record_batch = std:: mem:: replace (
1079
+ & mut self . output ,
1080
+ RecordBatch :: new_empty (
1081
+ out_filtered_batch. schema ( ) ,
1082
+ ) ,
1083
+ ) ;
1084
+ return Poll :: Ready ( Some ( Ok (
1085
+ record_batch,
1086
+ ) ) ) ;
1087
+ }
1063
1088
}
1064
1089
}
1065
1090
@@ -1116,7 +1141,7 @@ impl Stream for SortMergeJoinStream {
1116
1141
}
1117
1142
} else {
1118
1143
self . freeze_all ( ) ?;
1119
- if !self . output_record_batches . batches . is_empty ( ) {
1144
+ if !self . staging_output_record_batches . batches . is_empty ( ) {
1120
1145
let record_batch = self . output_record_batch_and_reset ( ) ?;
1121
1146
// For non-filtered join output whenever the target output batch size
1122
1147
// is hit. For filtered join its needed to output on later phase
@@ -1146,7 +1171,8 @@ impl Stream for SortMergeJoinStream {
1146
1171
SortMergeJoinState :: Exhausted => {
1147
1172
self . freeze_all ( ) ?;
1148
1173
1149
- if !self . output_record_batches . batches . is_empty ( ) {
1174
+ // if there is still something not processed
1175
+ if !self . staging_output_record_batches . batches . is_empty ( ) {
1150
1176
if self . filter . is_some ( )
1151
1177
&& matches ! (
1152
1178
self . join_type,
@@ -1159,12 +1185,20 @@ impl Stream for SortMergeJoinStream {
1159
1185
| JoinType :: LeftMark
1160
1186
)
1161
1187
{
1162
- let out = self . filter_joined_batch ( ) ?;
1163
- return Poll :: Ready ( Some ( Ok ( out ) ) ) ;
1188
+ let record_batch = self . filter_joined_batch ( ) ?;
1189
+ return Poll :: Ready ( Some ( Ok ( record_batch ) ) ) ;
1164
1190
} else {
1165
1191
let record_batch = self . output_record_batch_and_reset ( ) ?;
1166
1192
return Poll :: Ready ( Some ( Ok ( record_batch) ) ) ;
1167
1193
}
1194
+ } else if self . output . num_rows ( ) > 0 {
1195
+ // if processed but still not outputted because it didn't hit batch size before
1196
+ let schema = self . output . schema ( ) ;
1197
+ let record_batch = std:: mem:: replace (
1198
+ & mut self . output ,
1199
+ RecordBatch :: new_empty ( schema) ,
1200
+ ) ;
1201
+ return Poll :: Ready ( Some ( Ok ( record_batch) ) ) ;
1168
1202
} else {
1169
1203
return Poll :: Ready ( None ) ;
1170
1204
}
@@ -1197,7 +1231,7 @@ impl SortMergeJoinStream {
1197
1231
state : SortMergeJoinState :: Init ,
1198
1232
sort_options,
1199
1233
null_equals_null,
1200
- schema,
1234
+ schema : Arc :: clone ( & schema ) ,
1201
1235
streamed_schema : Arc :: clone ( & streamed_schema) ,
1202
1236
buffered_schema,
1203
1237
streamed,
@@ -1212,12 +1246,13 @@ impl SortMergeJoinStream {
1212
1246
on_streamed,
1213
1247
on_buffered,
1214
1248
filter,
1215
- output_record_batches : JoinedRecordBatches {
1249
+ staging_output_record_batches : JoinedRecordBatches {
1216
1250
batches : vec ! [ ] ,
1217
1251
filter_mask : BooleanBuilder :: new ( ) ,
1218
1252
row_indices : UInt64Builder :: new ( ) ,
1219
1253
batch_ids : vec ! [ ] ,
1220
1254
} ,
1255
+ output : RecordBatch :: new_empty ( schema) ,
1221
1256
output_size : 0 ,
1222
1257
batch_size,
1223
1258
join_type,
@@ -1607,17 +1642,20 @@ impl SortMergeJoinStream {
1607
1642
buffered_batch,
1608
1643
) ? {
1609
1644
let num_rows = record_batch. num_rows ( ) ;
1610
- self . output_record_batches
1645
+ self . staging_output_record_batches
1611
1646
. filter_mask
1612
1647
. append_nulls ( num_rows) ;
1613
- self . output_record_batches
1648
+ self . staging_output_record_batches
1614
1649
. row_indices
1615
1650
. append_nulls ( num_rows) ;
1616
- self . output_record_batches
1617
- . batch_ids
1618
- . resize ( self . output_record_batches . batch_ids . len ( ) + num_rows, 0 ) ;
1651
+ self . staging_output_record_batches . batch_ids . resize (
1652
+ self . staging_output_record_batches . batch_ids . len ( ) + num_rows,
1653
+ 0 ,
1654
+ ) ;
1619
1655
1620
- self . output_record_batches . batches . push ( record_batch) ;
1656
+ self . staging_output_record_batches
1657
+ . batches
1658
+ . push ( record_batch) ;
1621
1659
}
1622
1660
buffered_batch. null_joined . clear ( ) ;
1623
1661
}
@@ -1651,16 +1689,19 @@ impl SortMergeJoinStream {
1651
1689
) ? {
1652
1690
let num_rows = record_batch. num_rows ( ) ;
1653
1691
1654
- self . output_record_batches
1692
+ self . staging_output_record_batches
1655
1693
. filter_mask
1656
1694
. append_nulls ( num_rows) ;
1657
- self . output_record_batches
1695
+ self . staging_output_record_batches
1658
1696
. row_indices
1659
1697
. append_nulls ( num_rows) ;
1660
- self . output_record_batches
1661
- . batch_ids
1662
- . resize ( self . output_record_batches . batch_ids . len ( ) + num_rows, 0 ) ;
1663
- self . output_record_batches . batches . push ( record_batch) ;
1698
+ self . staging_output_record_batches . batch_ids . resize (
1699
+ self . staging_output_record_batches . batch_ids . len ( ) + num_rows,
1700
+ 0 ,
1701
+ ) ;
1702
+ self . staging_output_record_batches
1703
+ . batches
1704
+ . push ( record_batch) ;
1664
1705
}
1665
1706
buffered_batch. join_filter_not_matched_map . clear ( ) ;
1666
1707
@@ -1792,20 +1833,29 @@ impl SortMergeJoinStream {
1792
1833
| JoinType :: LeftMark
1793
1834
| JoinType :: Full
1794
1835
) {
1795
- self . output_record_batches . batches . push ( output_batch) ;
1836
+ self . staging_output_record_batches
1837
+ . batches
1838
+ . push ( output_batch) ;
1796
1839
} else {
1797
1840
let filtered_batch = filter_record_batch ( & output_batch, & mask) ?;
1798
- self . output_record_batches . batches . push ( filtered_batch) ;
1841
+ self . staging_output_record_batches
1842
+ . batches
1843
+ . push ( filtered_batch) ;
1799
1844
}
1800
1845
1801
1846
if !matches ! ( self . join_type, JoinType :: Full ) {
1802
- self . output_record_batches . filter_mask . extend ( & mask) ;
1847
+ self . staging_output_record_batches . filter_mask . extend ( & mask) ;
1803
1848
} else {
1804
- self . output_record_batches . filter_mask . extend ( pre_mask) ;
1849
+ self . staging_output_record_batches
1850
+ . filter_mask
1851
+ . extend ( pre_mask) ;
1805
1852
}
1806
- self . output_record_batches . row_indices . extend ( & left_indices) ;
1807
- self . output_record_batches . batch_ids . resize (
1808
- self . output_record_batches . batch_ids . len ( ) + left_indices. len ( ) ,
1853
+ self . staging_output_record_batches
1854
+ . row_indices
1855
+ . extend ( & left_indices) ;
1856
+ self . staging_output_record_batches . batch_ids . resize (
1857
+ self . staging_output_record_batches . batch_ids . len ( )
1858
+ + left_indices. len ( ) ,
1809
1859
self . streamed_batch_counter . load ( Relaxed ) ,
1810
1860
) ;
1811
1861
@@ -1837,10 +1887,14 @@ impl SortMergeJoinStream {
1837
1887
}
1838
1888
}
1839
1889
} else {
1840
- self . output_record_batches . batches . push ( output_batch) ;
1890
+ self . staging_output_record_batches
1891
+ . batches
1892
+ . push ( output_batch) ;
1841
1893
}
1842
1894
} else {
1843
- self . output_record_batches . batches . push ( output_batch) ;
1895
+ self . staging_output_record_batches
1896
+ . batches
1897
+ . push ( output_batch) ;
1844
1898
}
1845
1899
}
1846
1900
@@ -1851,7 +1905,7 @@ impl SortMergeJoinStream {
1851
1905
1852
1906
fn output_record_batch_and_reset ( & mut self ) -> Result < RecordBatch > {
1853
1907
let record_batch =
1854
- concat_batches ( & self . schema , & self . output_record_batches . batches ) ?;
1908
+ concat_batches ( & self . schema , & self . staging_output_record_batches . batches ) ?;
1855
1909
self . join_metrics . output_batches . add ( 1 ) ;
1856
1910
self . join_metrics . output_rows . add ( record_batch. num_rows ( ) ) ;
1857
1911
// If join filter exists, `self.output_size` is not accurate as we don't know the exact
@@ -1877,16 +1931,17 @@ impl SortMergeJoinStream {
1877
1931
| JoinType :: Full
1878
1932
) )
1879
1933
{
1880
- self . output_record_batches . batches . clear ( ) ;
1934
+ self . staging_output_record_batches . batches . clear ( ) ;
1881
1935
}
1882
1936
Ok ( record_batch)
1883
1937
}
1884
1938
1885
1939
fn filter_joined_batch ( & mut self ) -> Result < RecordBatch > {
1886
- let record_batch = self . output_record_batch_and_reset ( ) ?;
1887
- let mut out_indices = self . output_record_batches . row_indices . finish ( ) ;
1888
- let mut out_mask = self . output_record_batches . filter_mask . finish ( ) ;
1889
- let mut batch_ids = & self . output_record_batches . batch_ids ;
1940
+ let record_batch =
1941
+ concat_batches ( & self . schema , & self . staging_output_record_batches . batches ) ?;
1942
+ let mut out_indices = self . staging_output_record_batches . row_indices . finish ( ) ;
1943
+ let mut out_mask = self . staging_output_record_batches . filter_mask . finish ( ) ;
1944
+ let mut batch_ids = & self . staging_output_record_batches . batch_ids ;
1890
1945
let default_batch_ids = vec ! [ 0 ; record_batch. num_rows( ) ] ;
1891
1946
1892
1947
// If only nulls come in and indices sizes doesn't match with expected record batch count
@@ -1901,7 +1956,7 @@ impl SortMergeJoinStream {
1901
1956
}
1902
1957
1903
1958
if out_mask. is_empty ( ) {
1904
- self . output_record_batches . batches . clear ( ) ;
1959
+ self . staging_output_record_batches . batches . clear ( ) ;
1905
1960
return Ok ( record_batch) ;
1906
1961
}
1907
1962
@@ -2044,7 +2099,7 @@ impl SortMergeJoinStream {
2044
2099
) ?;
2045
2100
}
2046
2101
2047
- self . output_record_batches . clear ( ) ;
2102
+ self . staging_output_record_batches . clear ( ) ;
2048
2103
2049
2104
Ok ( filtered_record_batch)
2050
2105
}
0 commit comments