21
21
use std:: any:: Any ;
22
22
use std:: pin:: Pin ;
23
23
use std:: sync:: Arc ;
24
- use std:: task:: { Context , Poll } ;
24
+ use std:: task:: { ready , Context , Poll } ;
25
25
26
26
use super :: metrics:: { BaselineMetrics , ExecutionPlanMetricsSet , MetricsSet } ;
27
27
use super :: { DisplayAs , ExecutionPlanProperties , PlanProperties , Statistics } ;
@@ -146,10 +146,7 @@ impl ExecutionPlan for CoalesceBatchesExec {
146
146
) -> Result < SendableRecordBatchStream > {
147
147
Ok ( Box :: pin ( CoalesceBatchesStream {
148
148
input : self . input . execute ( partition, context) ?,
149
- schema : self . input . schema ( ) ,
150
- target_batch_size : self . target_batch_size ,
151
- buffer : Vec :: new ( ) ,
152
- buffered_rows : 0 ,
149
+ coalescer : BatchCoalescer :: new ( self . input . schema ( ) , self . target_batch_size ) ,
153
150
is_closed : false ,
154
151
baseline_metrics : BaselineMetrics :: new ( & self . metrics , partition) ,
155
152
} ) )
@@ -167,14 +164,8 @@ impl ExecutionPlan for CoalesceBatchesExec {
167
164
struct CoalesceBatchesStream {
168
165
/// The input plan
169
166
input : SendableRecordBatchStream ,
170
- /// The input schema
171
- schema : SchemaRef ,
172
- /// Minimum number of rows for coalesces batches
173
- target_batch_size : usize ,
174
- /// Buffered batches
175
- buffer : Vec < RecordBatch > ,
176
- /// Buffered row count
177
- buffered_rows : usize ,
167
+ /// Buffer for combining batches
168
+ coalescer : BatchCoalescer ,
178
169
/// Whether the stream has finished returning all of its data or not
179
170
is_closed : bool ,
180
171
/// Execution metrics
@@ -213,66 +204,35 @@ impl CoalesceBatchesStream {
213
204
let input_batch = self . input . poll_next_unpin ( cx) ;
214
205
// records time on drop
215
206
let _timer = cloned_time. timer ( ) ;
216
- match input_batch {
217
- Poll :: Ready ( x) => match x {
218
- Some ( Ok ( batch) ) => {
219
- if batch. num_rows ( ) >= self . target_batch_size
220
- && self . buffer . is_empty ( )
221
- {
222
- return Poll :: Ready ( Some ( Ok ( batch) ) ) ;
223
- } else if batch. num_rows ( ) == 0 {
224
- // discard empty batches
225
- } else {
226
- // add to the buffered batches
227
- self . buffered_rows += batch. num_rows ( ) ;
228
- self . buffer . push ( batch) ;
229
- // check to see if we have enough batches yet
230
- if self . buffered_rows >= self . target_batch_size {
231
- // combine the batches and return
232
- let batch = concat_batches (
233
- & self . schema ,
234
- & self . buffer ,
235
- self . buffered_rows ,
236
- ) ?;
237
- // reset buffer state
238
- self . buffer . clear ( ) ;
239
- self . buffered_rows = 0 ;
240
- // return batch
241
- return Poll :: Ready ( Some ( Ok ( batch) ) ) ;
242
- }
243
- }
244
- }
245
- None => {
246
- self . is_closed = true ;
247
- // we have reached the end of the input stream but there could still
248
- // be buffered batches
249
- if self . buffer . is_empty ( ) {
250
- return Poll :: Ready ( None ) ;
251
- } else {
252
- // combine the batches and return
253
- let batch = concat_batches (
254
- & self . schema ,
255
- & self . buffer ,
256
- self . buffered_rows ,
257
- ) ?;
258
- // reset buffer state
259
- self . buffer . clear ( ) ;
260
- self . buffered_rows = 0 ;
261
- // return batch
262
- return Poll :: Ready ( Some ( Ok ( batch) ) ) ;
263
- }
207
+ match ready ! ( input_batch) {
208
+ Some ( result) => {
209
+ let Ok ( input_batch) = result else {
210
+ return Poll :: Ready ( Some ( result) ) ; // pass back error
211
+ } ;
212
+ // Buffer the batch and either get more input if not enough
213
+ // rows yet or output
214
+ match self . coalescer . push_batch ( input_batch) {
215
+ Ok ( None ) => continue ,
216
+ res => return Poll :: Ready ( res. transpose ( ) ) ,
264
217
}
265
- other => return Poll :: Ready ( other) ,
266
- } ,
267
- Poll :: Pending => return Poll :: Pending ,
218
+ }
219
+ None => {
220
+ self . is_closed = true ;
221
+ // we have reached the end of the input stream but there could still
222
+ // be buffered batches
223
+ return match self . coalescer . finish ( ) {
224
+ Ok ( None ) => Poll :: Ready ( None ) ,
225
+ res => Poll :: Ready ( res. transpose ( ) ) ,
226
+ } ;
227
+ }
268
228
}
269
229
}
270
230
}
271
231
}
272
232
273
233
impl RecordBatchStream for CoalesceBatchesStream {
274
234
fn schema ( & self ) -> SchemaRef {
275
- Arc :: clone ( & self . schema )
235
+ self . coalescer . schema ( )
276
236
}
277
237
}
278
238
@@ -290,26 +250,106 @@ pub fn concat_batches(
290
250
arrow:: compute:: concat_batches ( schema, batches)
291
251
}
292
252
253
+ /// Concatenating multiple record batches into larger batches
254
+ ///
255
+ /// TODO ASCII ART
256
+ ///
257
+ /// Notes:
258
+ ///
259
+ /// 1. The output is exactly the same order as the input rows
260
+ ///
261
+ /// 2. The output is a sequence of batches, with all but the last being at least
262
+ /// `target_batch_size` rows.
263
+ ///
264
+ /// 3. Eventually this may also be able to handle other optimizations such as a
265
+ /// combined filter/coalesce operation.
266
+ #[ derive( Debug ) ]
267
+ struct BatchCoalescer {
268
+ /// The input schema
269
+ schema : SchemaRef ,
270
+ /// Minimum number of rows for coalesces batches
271
+ target_batch_size : usize ,
272
+ /// Buffered batches
273
+ buffer : Vec < RecordBatch > ,
274
+ /// Buffered row count
275
+ buffered_rows : usize ,
276
+ }
277
+
278
+ impl BatchCoalescer {
279
+ /// Create a new BatchCoalescer that produces batches of at least `target_batch_size` rows
280
+ fn new ( schema : SchemaRef , target_batch_size : usize ) -> Self {
281
+ Self {
282
+ schema,
283
+ target_batch_size,
284
+ buffer : vec ! [ ] ,
285
+ buffered_rows : 0 ,
286
+ }
287
+ }
288
+
289
+ /// Return the schema of the output batches
290
+ fn schema ( & self ) -> SchemaRef {
291
+ Arc :: clone ( & self . schema )
292
+ }
293
+
294
+ /// Add a batch to the coalescer, returning a batch if the target batch size is reached
295
+ fn push_batch ( & mut self , batch : RecordBatch ) -> Result < Option < RecordBatch > > {
296
+ if batch. num_rows ( ) >= self . target_batch_size && self . buffer . is_empty ( ) {
297
+ return Ok ( Some ( batch) ) ;
298
+ }
299
+ // discard empty batches
300
+ if batch. num_rows ( ) == 0 {
301
+ return Ok ( None ) ;
302
+ }
303
+ // add to the buffered batches
304
+ self . buffered_rows += batch. num_rows ( ) ;
305
+ self . buffer . push ( batch) ;
306
+ // check to see if we have enough batches yet
307
+ let batch = if self . buffered_rows >= self . target_batch_size {
308
+ // combine the batches and return
309
+ let batch = concat_batches ( & self . schema , & self . buffer , self . buffered_rows ) ?;
310
+ // reset buffer state
311
+ self . buffer . clear ( ) ;
312
+ self . buffered_rows = 0 ;
313
+ // return batch
314
+ Some ( batch)
315
+ } else {
316
+ None
317
+ } ;
318
+ Ok ( batch)
319
+ }
320
+
321
+ /// Finish the coalescing process, returning all buffered data as a final,
322
+ /// single batch, if any
323
+ fn finish ( & mut self ) -> Result < Option < RecordBatch > > {
324
+ if self . buffer . is_empty ( ) {
325
+ Ok ( None )
326
+ } else {
327
+ // combine the batches and return
328
+ let batch = concat_batches ( & self . schema , & self . buffer , self . buffered_rows ) ?;
329
+ // reset buffer state
330
+ self . buffer . clear ( ) ;
331
+ self . buffered_rows = 0 ;
332
+ // return batch
333
+ Ok ( Some ( batch) )
334
+ }
335
+ }
336
+ }
337
+
293
338
#[ cfg( test) ]
294
339
mod tests {
295
340
use super :: * ;
296
- use crate :: { memory:: MemoryExec , repartition:: RepartitionExec , Partitioning } ;
297
-
298
341
use arrow:: datatypes:: { DataType , Field , Schema } ;
299
342
use arrow_array:: UInt32Array ;
300
343
301
344
#[ tokio:: test( flavor = "multi_thread" ) ]
302
345
async fn test_concat_batches ( ) -> Result < ( ) > {
303
- let schema = test_schema ( ) ;
304
- let partition = create_vec_batches ( & schema, 10 ) ;
305
- let partitions = vec ! [ partition] ;
306
-
307
- let output_partitions = coalesce_batches ( & schema, partitions, 21 ) . await ?;
308
- assert_eq ! ( 1 , output_partitions. len( ) ) ;
346
+ let Scenario { schema, batch } = uint32_scenario ( ) ;
309
347
310
348
// input is 10 batches x 8 rows (80 rows)
349
+ let input = std:: iter:: repeat ( batch) . take ( 10 ) ;
350
+
311
351
// expected output is batches of at least 20 rows (except for the final batch)
312
- let batches = & output_partitions [ 0 ] ;
352
+ let batches = do_coalesce_batches ( & schema , input , 21 ) ;
313
353
assert_eq ! ( 4 , batches. len( ) ) ;
314
354
assert_eq ! ( 24 , batches[ 0 ] . num_rows( ) ) ;
315
355
assert_eq ! ( 24 , batches[ 1 ] . num_rows( ) ) ;
@@ -319,54 +359,43 @@ mod tests {
319
359
Ok ( ( ) )
320
360
}
321
361
322
- fn test_schema ( ) -> Arc < Schema > {
323
- Arc :: new ( Schema :: new ( vec ! [ Field :: new( "c0" , DataType :: UInt32 , false ) ] ) )
324
- }
325
-
326
- async fn coalesce_batches (
362
+ // Coalesce the batches with a BatchCoalescer function with the given input
363
+ // and target batch size returning the resulting batches
364
+ fn do_coalesce_batches (
327
365
schema : & SchemaRef ,
328
- input_partitions : Vec < Vec < RecordBatch > > ,
366
+ input : impl IntoIterator < Item = RecordBatch > ,
329
367
target_batch_size : usize ,
330
- ) -> Result < Vec < Vec < RecordBatch > > > {
368
+ ) -> Vec < RecordBatch > {
331
369
// create physical plan
332
- let exec = MemoryExec :: try_new ( & input_partitions, Arc :: clone ( schema) , None ) ?;
333
- let exec =
334
- RepartitionExec :: try_new ( Arc :: new ( exec) , Partitioning :: RoundRobinBatch ( 1 ) ) ?;
335
- let exec: Arc < dyn ExecutionPlan > =
336
- Arc :: new ( CoalesceBatchesExec :: new ( Arc :: new ( exec) , target_batch_size) ) ;
337
-
338
- // execute and collect results
339
- let output_partition_count = exec. output_partitioning ( ) . partition_count ( ) ;
340
- let mut output_partitions = Vec :: with_capacity ( output_partition_count) ;
341
- for i in 0 ..output_partition_count {
342
- // execute this *output* partition and collect all batches
343
- let task_ctx = Arc :: new ( TaskContext :: default ( ) ) ;
344
- let mut stream = exec. execute ( i, Arc :: clone ( & task_ctx) ) ?;
345
- let mut batches = vec ! [ ] ;
346
- while let Some ( result) = stream. next ( ) . await {
347
- batches. push ( result?) ;
348
- }
349
- output_partitions. push ( batches) ;
370
+ let mut coalescer = BatchCoalescer :: new ( Arc :: clone ( schema) , target_batch_size) ;
371
+ let mut output_batches: Vec < _ > = input
372
+ . into_iter ( )
373
+ . filter_map ( |batch| coalescer. push_batch ( batch) . unwrap ( ) )
374
+ . collect ( ) ;
375
+ if let Some ( batch) = coalescer. finish ( ) . unwrap ( ) {
376
+ output_batches. push ( batch) ;
350
377
}
351
- Ok ( output_partitions )
378
+ output_batches
352
379
}
353
380
354
- /// Create vector batches
355
- fn create_vec_batches ( schema : & Schema , n : usize ) -> Vec < RecordBatch > {
356
- let batch = create_batch ( schema) ;
357
- let mut vec = Vec :: with_capacity ( n) ;
358
- for _ in 0 ..n {
359
- vec. push ( batch. clone ( ) ) ;
360
- }
361
- vec
381
+ /// Test scenario
382
+ #[ derive( Debug ) ]
383
+ struct Scenario {
384
+ schema : Arc < Schema > ,
385
+ batch : RecordBatch ,
362
386
}
363
387
364
- /// Create batch
365
- fn create_batch ( schema : & Schema ) -> RecordBatch {
366
- RecordBatch :: try_new (
367
- Arc :: new ( schema. clone ( ) ) ,
388
+ /// a batch of 8 rows of UInt32
389
+ fn uint32_scenario ( ) -> Scenario {
390
+ let schema =
391
+ Arc :: new ( Schema :: new ( vec ! [ Field :: new( "c0" , DataType :: UInt32 , false ) ] ) ) ;
392
+
393
+ let batch = RecordBatch :: try_new (
394
+ Arc :: clone ( & schema) ,
368
395
vec ! [ Arc :: new( UInt32Array :: from( vec![ 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 ] ) ) ] ,
369
396
)
370
- . unwrap ( )
397
+ . unwrap ( ) ;
398
+
399
+ Scenario { schema, batch }
371
400
}
372
401
}
0 commit comments