Skip to content

Commit 117901b

Browse files
committed
Extract CoalesceBatchesStream to a struct
1 parent 77311a5 commit 117901b

File tree

1 file changed

+142
-113
lines changed

1 file changed

+142
-113
lines changed

datafusion/physical-plan/src/coalesce_batches.rs

Lines changed: 142 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
use std::any::Any;
2222
use std::pin::Pin;
2323
use std::sync::Arc;
24-
use std::task::{Context, Poll};
24+
use std::task::{ready, Context, Poll};
2525

2626
use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
2727
use super::{DisplayAs, ExecutionPlanProperties, PlanProperties, Statistics};
@@ -146,10 +146,7 @@ impl ExecutionPlan for CoalesceBatchesExec {
146146
) -> Result<SendableRecordBatchStream> {
147147
Ok(Box::pin(CoalesceBatchesStream {
148148
input: self.input.execute(partition, context)?,
149-
schema: self.input.schema(),
150-
target_batch_size: self.target_batch_size,
151-
buffer: Vec::new(),
152-
buffered_rows: 0,
149+
coalescer: BatchCoalescer::new(self.input.schema(), self.target_batch_size),
153150
is_closed: false,
154151
baseline_metrics: BaselineMetrics::new(&self.metrics, partition),
155152
}))
@@ -167,14 +164,8 @@ impl ExecutionPlan for CoalesceBatchesExec {
167164
struct CoalesceBatchesStream {
168165
/// The input plan
169166
input: SendableRecordBatchStream,
170-
/// The input schema
171-
schema: SchemaRef,
172-
/// Minimum number of rows for coalesces batches
173-
target_batch_size: usize,
174-
/// Buffered batches
175-
buffer: Vec<RecordBatch>,
176-
/// Buffered row count
177-
buffered_rows: usize,
167+
/// Buffer for combining batches
168+
coalescer: BatchCoalescer,
178169
/// Whether the stream has finished returning all of its data or not
179170
is_closed: bool,
180171
/// Execution metrics
@@ -213,66 +204,35 @@ impl CoalesceBatchesStream {
213204
let input_batch = self.input.poll_next_unpin(cx);
214205
// records time on drop
215206
let _timer = cloned_time.timer();
216-
match input_batch {
217-
Poll::Ready(x) => match x {
218-
Some(Ok(batch)) => {
219-
if batch.num_rows() >= self.target_batch_size
220-
&& self.buffer.is_empty()
221-
{
222-
return Poll::Ready(Some(Ok(batch)));
223-
} else if batch.num_rows() == 0 {
224-
// discard empty batches
225-
} else {
226-
// add to the buffered batches
227-
self.buffered_rows += batch.num_rows();
228-
self.buffer.push(batch);
229-
// check to see if we have enough batches yet
230-
if self.buffered_rows >= self.target_batch_size {
231-
// combine the batches and return
232-
let batch = concat_batches(
233-
&self.schema,
234-
&self.buffer,
235-
self.buffered_rows,
236-
)?;
237-
// reset buffer state
238-
self.buffer.clear();
239-
self.buffered_rows = 0;
240-
// return batch
241-
return Poll::Ready(Some(Ok(batch)));
242-
}
243-
}
244-
}
245-
None => {
246-
self.is_closed = true;
247-
// we have reached the end of the input stream but there could still
248-
// be buffered batches
249-
if self.buffer.is_empty() {
250-
return Poll::Ready(None);
251-
} else {
252-
// combine the batches and return
253-
let batch = concat_batches(
254-
&self.schema,
255-
&self.buffer,
256-
self.buffered_rows,
257-
)?;
258-
// reset buffer state
259-
self.buffer.clear();
260-
self.buffered_rows = 0;
261-
// return batch
262-
return Poll::Ready(Some(Ok(batch)));
263-
}
207+
match ready!(input_batch) {
208+
Some(result) => {
209+
let Ok(input_batch) = result else {
210+
return Poll::Ready(Some(result)); // pass back error
211+
};
212+
// Buffer the batch and either get more input if not enough
213+
// rows yet or output
214+
match self.coalescer.push_batch(input_batch) {
215+
Ok(None) => continue,
216+
res => return Poll::Ready(res.transpose()),
264217
}
265-
other => return Poll::Ready(other),
266-
},
267-
Poll::Pending => return Poll::Pending,
218+
}
219+
None => {
220+
self.is_closed = true;
221+
// we have reached the end of the input stream but there could still
222+
// be buffered batches
223+
return match self.coalescer.finish() {
224+
Ok(None) => Poll::Ready(None),
225+
res => Poll::Ready(res.transpose()),
226+
};
227+
}
268228
}
269229
}
270230
}
271231
}
272232

273233
impl RecordBatchStream for CoalesceBatchesStream {
274234
fn schema(&self) -> SchemaRef {
275-
Arc::clone(&self.schema)
235+
self.coalescer.schema()
276236
}
277237
}
278238

@@ -290,26 +250,106 @@ pub fn concat_batches(
290250
arrow::compute::concat_batches(schema, batches)
291251
}
292252

253+
/// Concatenating multiple record batches into larger batches
254+
///
255+
/// TODO ASCII ART
256+
///
257+
/// Notes:
258+
///
259+
/// 1. The output is exactly the same order as the input rows
260+
///
261+
/// 2. The output is a sequence of batches, with all but the last being at least
262+
/// `target_batch_size` rows.
263+
///
264+
/// 3. Eventually this may also be able to handle other optimizations such as a
265+
/// combined filter/coalesce operation.
266+
#[derive(Debug)]
267+
struct BatchCoalescer {
268+
/// The input schema
269+
schema: SchemaRef,
270+
/// Minimum number of rows for coalesces batches
271+
target_batch_size: usize,
272+
/// Buffered batches
273+
buffer: Vec<RecordBatch>,
274+
/// Buffered row count
275+
buffered_rows: usize,
276+
}
277+
278+
impl BatchCoalescer {
279+
/// Create a new BatchCoalescer that produces batches of at least `target_batch_size` rows
280+
fn new(schema: SchemaRef, target_batch_size: usize) -> Self {
281+
Self {
282+
schema,
283+
target_batch_size,
284+
buffer: vec![],
285+
buffered_rows: 0,
286+
}
287+
}
288+
289+
/// Return the schema of the output batches
290+
fn schema(&self) -> SchemaRef {
291+
Arc::clone(&self.schema)
292+
}
293+
294+
/// Add a batch to the coalescer, returning a batch if the target batch size is reached
295+
fn push_batch(&mut self, batch: RecordBatch) -> Result<Option<RecordBatch>> {
296+
if batch.num_rows() >= self.target_batch_size && self.buffer.is_empty() {
297+
return Ok(Some(batch));
298+
}
299+
// discard empty batches
300+
if batch.num_rows() == 0 {
301+
return Ok(None);
302+
}
303+
// add to the buffered batches
304+
self.buffered_rows += batch.num_rows();
305+
self.buffer.push(batch);
306+
// check to see if we have enough batches yet
307+
let batch = if self.buffered_rows >= self.target_batch_size {
308+
// combine the batches and return
309+
let batch = concat_batches(&self.schema, &self.buffer, self.buffered_rows)?;
310+
// reset buffer state
311+
self.buffer.clear();
312+
self.buffered_rows = 0;
313+
// return batch
314+
Some(batch)
315+
} else {
316+
None
317+
};
318+
Ok(batch)
319+
}
320+
321+
/// Finish the coalescing process, returning all buffered data as a final,
322+
/// single batch, if any
323+
fn finish(&mut self) -> Result<Option<RecordBatch>> {
324+
if self.buffer.is_empty() {
325+
Ok(None)
326+
} else {
327+
// combine the batches and return
328+
let batch = concat_batches(&self.schema, &self.buffer, self.buffered_rows)?;
329+
// reset buffer state
330+
self.buffer.clear();
331+
self.buffered_rows = 0;
332+
// return batch
333+
Ok(Some(batch))
334+
}
335+
}
336+
}
337+
293338
#[cfg(test)]
294339
mod tests {
295340
use super::*;
296-
use crate::{memory::MemoryExec, repartition::RepartitionExec, Partitioning};
297-
298341
use arrow::datatypes::{DataType, Field, Schema};
299342
use arrow_array::UInt32Array;
300343

301344
#[tokio::test(flavor = "multi_thread")]
302345
async fn test_concat_batches() -> Result<()> {
303-
let schema = test_schema();
304-
let partition = create_vec_batches(&schema, 10);
305-
let partitions = vec![partition];
306-
307-
let output_partitions = coalesce_batches(&schema, partitions, 21).await?;
308-
assert_eq!(1, output_partitions.len());
346+
let Scenario { schema, batch } = uint32_scenario();
309347

310348
// input is 10 batches x 8 rows (80 rows)
349+
let input = std::iter::repeat(batch).take(10);
350+
311351
// expected output is batches of at least 21 rows (except for the final batch)
312-
let batches = &output_partitions[0];
352+
let batches = do_coalesce_batches(&schema, input, 21);
313353
assert_eq!(4, batches.len());
314354
assert_eq!(24, batches[0].num_rows());
315355
assert_eq!(24, batches[1].num_rows());
@@ -319,54 +359,43 @@ mod tests {
319359
Ok(())
320360
}
321361

322-
fn test_schema() -> Arc<Schema> {
323-
Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]))
324-
}
325-
326-
async fn coalesce_batches(
362+
// Coalesce the batches using a BatchCoalescer with the given input
363+
// and target batch size returning the resulting batches
364+
fn do_coalesce_batches(
327365
schema: &SchemaRef,
328-
input_partitions: Vec<Vec<RecordBatch>>,
366+
input: impl IntoIterator<Item = RecordBatch>,
329367
target_batch_size: usize,
330-
) -> Result<Vec<Vec<RecordBatch>>> {
368+
) -> Vec<RecordBatch> {
331369
// create the coalescer under test
332-
let exec = MemoryExec::try_new(&input_partitions, Arc::clone(schema), None)?;
333-
let exec =
334-
RepartitionExec::try_new(Arc::new(exec), Partitioning::RoundRobinBatch(1))?;
335-
let exec: Arc<dyn ExecutionPlan> =
336-
Arc::new(CoalesceBatchesExec::new(Arc::new(exec), target_batch_size));
337-
338-
// execute and collect results
339-
let output_partition_count = exec.output_partitioning().partition_count();
340-
let mut output_partitions = Vec::with_capacity(output_partition_count);
341-
for i in 0..output_partition_count {
342-
// execute this *output* partition and collect all batches
343-
let task_ctx = Arc::new(TaskContext::default());
344-
let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
345-
let mut batches = vec![];
346-
while let Some(result) = stream.next().await {
347-
batches.push(result?);
348-
}
349-
output_partitions.push(batches);
370+
let mut coalescer = BatchCoalescer::new(Arc::clone(schema), target_batch_size);
371+
let mut output_batches: Vec<_> = input
372+
.into_iter()
373+
.filter_map(|batch| coalescer.push_batch(batch).unwrap())
374+
.collect();
375+
if let Some(batch) = coalescer.finish().unwrap() {
376+
output_batches.push(batch);
350377
}
351-
Ok(output_partitions)
378+
output_batches
352379
}
353380

354-
/// Create vector batches
355-
fn create_vec_batches(schema: &Schema, n: usize) -> Vec<RecordBatch> {
356-
let batch = create_batch(schema);
357-
let mut vec = Vec::with_capacity(n);
358-
for _ in 0..n {
359-
vec.push(batch.clone());
360-
}
361-
vec
381+
/// Test scenario
382+
#[derive(Debug)]
383+
struct Scenario {
384+
schema: Arc<Schema>,
385+
batch: RecordBatch,
362386
}
363387

364-
/// Create batch
365-
fn create_batch(schema: &Schema) -> RecordBatch {
366-
RecordBatch::try_new(
367-
Arc::new(schema.clone()),
388+
/// a batch of 8 rows of UInt32
389+
fn uint32_scenario() -> Scenario {
390+
let schema =
391+
Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]));
392+
393+
let batch = RecordBatch::try_new(
394+
Arc::clone(&schema),
368395
vec![Arc::new(UInt32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]))],
369396
)
370-
.unwrap()
397+
.unwrap();
398+
399+
Scenario { schema, batch }
371400
}
372401
}

0 commit comments

Comments
 (0)