
Commit f24a724

Authored by mustafasrepo, ozankabak, and akurmustafa
Order Preserving RepartitionExec Implementation (#6742)
* Write tests for functionality
* Implement sort preserving repartition exec
* Minor changes
* Implement second design (per partition merge)
* Simplifications
* Address reviews
* Move the fuzz test to appropriate folder, improve comments
* Decrease code duplication
* simplifications
* Update comment

Co-authored-by: Mehmet Ozan Kabak <[email protected]>
Co-authored-by: Mustafa Akur <[email protected]>
1 parent 1522e7a commit f24a724

File tree

4 files changed (+405, -25 lines)

datafusion/core/src/physical_plan/repartition/distributor_channels.rs

Lines changed: 13 additions & 0 deletions
@@ -83,6 +83,19 @@ pub fn channels<T>(
     (senders, receivers)
 }
 
+type PartitionAwareSenders<T> = Vec<Vec<DistributionSender<T>>>;
+type PartitionAwareReceivers<T> = Vec<Vec<DistributionReceiver<T>>>;
+
+/// Create `n_out` empty channels for each of the `n_in` inputs.
+/// This way, each distinct partition will communicate via a dedicated channel.
+/// This SPSC structure enables us to track which partition input data comes from.
+pub fn partition_aware_channels<T>(
+    n_in: usize,
+    n_out: usize,
+) -> (PartitionAwareSenders<T>, PartitionAwareReceivers<T>) {
+    (0..n_in).map(|_| channels(n_out)).unzip()
+}
+
 /// Erroring during [send](DistributionSender::send).
 ///
 /// This occurs when the [receiver](DistributionReceiver) is gone.
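To make the shape of this helper concrete, here is a minimal standalone sketch of the same idea, using std::sync::mpsc channels in place of the crate's DistributionSender/DistributionReceiver (the names below are stand-ins for illustration, not DataFusion's API):

use std::sync::mpsc::{channel, Receiver, Sender};

/// One channel per output partition (stand-in for the crate's `channels`).
fn channels<T>(n: usize) -> (Vec<Sender<T>>, Vec<Receiver<T>>) {
    (0..n).map(|_| channel()).unzip()
}

/// One `channels(n_out)` group per input partition: an `n_in x n_out` grid,
/// outer index = input partition, inner index = output partition.
fn partition_aware_channels<T>(
    n_in: usize,
    n_out: usize,
) -> (Vec<Vec<Sender<T>>>, Vec<Vec<Receiver<T>>>) {
    (0..n_in).map(|_| channels(n_out)).unzip()
}

fn main() {
    // 3 input partitions and 2 output partitions give a 3 x 2 grid of channels.
    let (senders, receivers) = partition_aware_channels::<i32>(3, 2);
    assert_eq!(senders.len(), 3);
    assert_eq!(senders[0].len(), 2);
    // The channel at [input][output] is dedicated to that pair, so data
    // arriving on receivers[1][0] is known to come from input partition 1.
    senders[1][0].send(42).unwrap();
    assert_eq!(receivers[1][0].recv().unwrap(), 42);
}

Because every (input, output) pair owns a dedicated SPSC channel, the receiving side always knows which input partition a batch came from, which is what the per-partition merge below relies on.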

datafusion/core/src/physical_plan/repartition/mod.rs

Lines changed: 154 additions & 25 deletions
@@ -16,15 +16,18 @@
 // under the License.
 
 //! The repartition operator maps N input partitions to M output partitions based on a
-//! partitioning scheme.
+//! partitioning scheme (according to flag `preserve_order` ordering can be preserved during
+//! repartitioning if its input is ordered).
 
 use std::pin::Pin;
 use std::sync::Arc;
 use std::task::{Context, Poll};
 use std::{any::Any, vec};
 
 use crate::physical_plan::hash_utils::create_hashes;
-use crate::physical_plan::repartition::distributor_channels::channels;
+use crate::physical_plan::repartition::distributor_channels::{
+    channels, partition_aware_channels,
+};
 use crate::physical_plan::{
     DisplayFormatType, EquivalenceProperties, ExecutionPlan, Partitioning, Statistics,
 };
@@ -42,6 +45,9 @@ use super::expressions::PhysicalSortExpr;
 use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
 use super::{RecordBatchStream, SendableRecordBatchStream};
 
+use crate::physical_plan::common::transpose;
+use crate::physical_plan::metrics::BaselineMetrics;
+use crate::physical_plan::sorts::streaming_merge;
 use datafusion_execution::TaskContext;
 use datafusion_physical_expr::PhysicalExpr;
 use futures::stream::Stream;
@@ -53,6 +59,8 @@ use tokio::task::JoinHandle;
 mod distributor_channels;
 
 type MaybeBatch = Option<Result<RecordBatch>>;
+type InputPartitionsToCurrentPartitionSender = Vec<DistributionSender<MaybeBatch>>;
+type InputPartitionsToCurrentPartitionReceiver = Vec<DistributionReceiver<MaybeBatch>>;
 
 /// Inner state of [`RepartitionExec`].
 #[derive(Debug)]
@@ -62,8 +70,8 @@ struct RepartitionExecState {
     channels: HashMap<
         usize,
         (
-            DistributionSender<MaybeBatch>,
-            DistributionReceiver<MaybeBatch>,
+            InputPartitionsToCurrentPartitionSender,
+            InputPartitionsToCurrentPartitionReceiver,
             SharedMemoryReservation,
         ),
     >,
@@ -245,6 +253,9 @@ pub struct RepartitionExec {
 
     /// Execution metrics
     metrics: ExecutionPlanMetricsSet,
+
+    /// Boolean flag to decide whether to preserve ordering
+    preserve_order: bool,
 }
 
 #[derive(Debug, Clone)]
@@ -298,6 +309,15 @@ impl RepartitionExec {
     pub fn partitioning(&self) -> &Partitioning {
         &self.partitioning
     }
+
+    /// Get name of the Executor
+    pub fn name(&self) -> &str {
+        if self.preserve_order {
+            "SortPreservingRepartitionExec"
+        } else {
+            "RepartitionExec"
+        }
+    }
 }
 
 impl ExecutionPlan for RepartitionExec {
@@ -345,8 +365,12 @@ impl ExecutionPlan for RepartitionExec {
     }
 
     fn maintains_input_order(&self) -> Vec<bool> {
-        // We preserve ordering when input partitioning is 1
-        vec![self.input().output_partitioning().partition_count() <= 1]
+        if self.preserve_order {
+            vec![true]
+        } else {
+            // We preserve ordering when input partitioning is 1
+            vec![self.input().output_partitioning().partition_count() <= 1]
+        }
     }
 
     fn equivalence_properties(&self) -> EquivalenceProperties {
@@ -359,7 +383,8 @@
         context: Arc<TaskContext>,
     ) -> Result<SendableRecordBatchStream> {
         trace!(
-            "Start RepartitionExec::execute for partition: {}",
+            "Start {}::execute for partition: {}",
+            self.name(),
             partition
         );
         // lock mutexes
@@ -370,13 +395,29 @@
 
         // if this is the first partition to be invoked then we need to set up initial state
         if state.channels.is_empty() {
-            // create one channel per *output* partition
-            // note we use a custom channel that ensures there is always data for each receiver
-            // but limits the amount of buffering if required.
-            let (txs, rxs) = channels(num_output_partitions);
+            let (txs, rxs) = if self.preserve_order {
+                let (txs, rxs) =
+                    partition_aware_channels(num_input_partitions, num_output_partitions);
+                // Take transpose of senders and receivers. `state.channels` keeps track of entries per output partition
+                let txs = transpose(txs);
+                let rxs = transpose(rxs);
+                (txs, rxs)
+            } else {
+                // create one channel per *output* partition
+                // note we use a custom channel that ensures there is always data for each receiver
+                // but limits the amount of buffering if required.
+                let (txs, rxs) = channels(num_output_partitions);
+                // Clone sender for each input partition
+                let txs = txs
+                    .into_iter()
+                    .map(|item| vec![item; num_input_partitions])
+                    .collect::<Vec<_>>();
+                let rxs = rxs.into_iter().map(|item| vec![item]).collect::<Vec<_>>();
+                (txs, rxs)
+            };
             for (partition, (tx, rx)) in txs.into_iter().zip(rxs).enumerate() {
                 let reservation = Arc::new(Mutex::new(
-                    MemoryConsumer::new(format!("RepartitionExec[{partition}]"))
+                    MemoryConsumer::new(format!("{}[{partition}]", self.name()))
                         .register(context.memory_pool()),
                 ));
                 state.channels.insert(partition, (tx, rx, reservation));
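The `transpose` call above flips the grid returned by `partition_aware_channels` from input-major to output-major, so `state.channels` stays keyed by output partition. A small self-contained sketch of that reshaping (the real helper lives in `crate::physical_plan::common`; this version is only illustrative):

fn transpose<T>(original: Vec<Vec<T>>) -> Vec<Vec<T>> {
    match original.first() {
        Some(first) => {
            let n_cols = first.len();
            // One accumulator per column of the original grid.
            let mut result: Vec<Vec<T>> = (0..n_cols).map(|_| Vec::new()).collect();
            for row in original {
                for (col, item) in row.into_iter().enumerate() {
                    result[col].push(item);
                }
            }
            result
        }
        None => vec![],
    }
}

fn main() {
    // 3 input partitions x 2 output partitions, written as (input, output) pairs.
    let grid = vec![
        vec![(0, 0), (0, 1)],
        vec![(1, 0), (1, 1)],
        vec![(2, 0), (2, 1)],
    ];
    let flipped = transpose(grid);
    // After the flip, row `o` holds one entry per input partition, which is
    // exactly what output partition `o` needs in order to merge its inputs.
    assert_eq!(flipped[0], vec![(0, 0), (1, 0), (2, 0)]);
    assert_eq!(flipped[1], vec![(0, 1), (1, 1), (2, 1)]);
}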
@@ -389,7 +430,7 @@
                 .channels
                 .iter()
                 .map(|(partition, (tx, _rx, reservation))| {
-                    (*partition, (tx.clone(), Arc::clone(reservation)))
+                    (*partition, (tx[i].clone(), Arc::clone(reservation)))
                 })
                 .collect();
 
@@ -420,24 +461,53 @@
         }
 
         trace!(
-            "Before returning stream in RepartitionExec::execute for partition: {}",
+            "Before returning stream in {}::execute for partition: {}",
+            self.name(),
             partition
         );
 
         // now return stream for the specified *output* partition which will
         // read from the channel
-        let (_tx, rx, reservation) = state
+        let (_tx, mut rx, reservation) = state
             .channels
             .remove(&partition)
             .expect("partition not used yet");
-        Ok(Box::pin(RepartitionStream {
-            num_input_partitions,
-            num_input_partitions_processed: 0,
-            schema: self.input.schema(),
-            input: rx,
-            drop_helper: Arc::clone(&state.abort_helper),
-            reservation,
-        }))
+
+        if self.preserve_order {
+            // Store streams from all the input partitions:
+            let input_streams = rx
+                .into_iter()
+                .map(|receiver| {
+                    Box::pin(PerPartitionStream {
+                        schema: self.schema(),
+                        receiver,
+                        drop_helper: Arc::clone(&state.abort_helper),
+                        reservation: reservation.clone(),
+                    }) as SendableRecordBatchStream
+                })
+                .collect::<Vec<_>>();
+            // Note that receiver size (`rx.len()`) and `num_input_partitions` are the same.
+
+            // Get existing ordering:
+            let sort_exprs = self.input.output_ordering().unwrap_or(&[]);
+            // Merge streams (while preserving ordering) coming from input partitions to this partition:
+            streaming_merge(
+                input_streams,
+                self.schema(),
+                sort_exprs,
+                BaselineMetrics::new(&self.metrics, partition),
+                context.session_config().batch_size(),
+            )
+        } else {
+            Ok(Box::pin(RepartitionStream {
+                num_input_partitions,
+                num_input_partitions_processed: 0,
+                schema: self.input.schema(),
+                input: rx.swap_remove(0),
+                drop_helper: Arc::clone(&state.abort_helper),
+                reservation,
+            }))
+        }
     }
 
     fn metrics(&self) -> Option<MetricsSet> {
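The `streaming_merge` call is what actually preserves the ordering: each per-input stream is already sorted by the plan's output ordering, and the merge interleaves them so the output partition stays sorted. As a rough, synchronous analogy of that step (plain sorted integer vectors instead of async RecordBatch streams and sort expressions):

use std::cmp::Reverse;
use std::collections::BinaryHeap;

fn merge_sorted(inputs: Vec<Vec<i32>>) -> Vec<i32> {
    // Heap entries: (next value, which input it came from, position in that input).
    let mut heap = BinaryHeap::new();
    for (idx, input) in inputs.iter().enumerate() {
        if let Some(&first) = input.first() {
            heap.push(Reverse((first, idx, 0usize)));
        }
    }
    let mut merged = Vec::new();
    while let Some(Reverse((value, idx, pos))) = heap.pop() {
        merged.push(value);
        // Refill the heap with the next value from the stream we just consumed.
        if let Some(&next) = inputs[idx].get(pos + 1) {
            heap.push(Reverse((next, idx, pos + 1)));
        }
    }
    merged
}

fn main() {
    // Three "input partitions", each individually sorted, as the plan's
    // output ordering guarantees for the real operator.
    let per_partition = vec![vec![1, 4, 7], vec![2, 5, 8], vec![3, 6, 9]];
    assert_eq!(merge_sorted(per_partition), vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
}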
@@ -453,7 +523,8 @@
             DisplayFormatType::Default | DisplayFormatType::Verbose => {
                 write!(
                     f,
-                    "RepartitionExec: partitioning={}, input_partitions={}",
+                    "{}: partitioning={}, input_partitions={}",
+                    self.name(),
                     self.partitioning,
                     self.input.output_partitioning().partition_count()
                 )
@@ -480,9 +551,16 @@ impl RepartitionExec {
                 abort_helper: Arc::new(AbortOnDropMany::<()>(vec![])),
             })),
             metrics: ExecutionPlanMetricsSet::new(),
+            preserve_order: false,
         })
     }
 
+    /// Set Order preserving flag
+    pub fn with_preserve_order(mut self) -> Self {
+        self.preserve_order = true;
+        self
+    }
+
     /// Pulls data from the specified input plan, feeding it to the
     /// output partitions based on the desired partitioning
     ///
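Callers opt into the new behavior with the builder-style flag added above. A hedged usage sketch, assuming the crate's public re-export paths at the time of this commit (exact import paths may differ by DataFusion version):

use std::sync::Arc;

use datafusion::error::Result;
use datafusion::physical_plan::repartition::RepartitionExec;
use datafusion::physical_plan::{ExecutionPlan, Partitioning};

/// Repartition an ordered input into `n` partitions while keeping each
/// output partition sorted by the input's existing ordering.
fn order_preserving_repartition(
    input: Arc<dyn ExecutionPlan>,
    n: usize,
) -> Result<Arc<dyn ExecutionPlan>> {
    let exec = RepartitionExec::try_new(input, Partitioning::RoundRobinBatch(n))?
        // Without this call the operator behaves exactly as before; with it,
        // each output partition streams-merges its per-input channels.
        .with_preserve_order();
    Ok(Arc::new(exec))
}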
@@ -575,7 +653,7 @@
     /// channels.
     async fn wait_for_task(
         input_task: AbortOnDropSingle<Result<()>>,
-        txs: HashMap<usize, DistributionSender<Option<Result<RecordBatch>>>>,
+        txs: HashMap<usize, DistributionSender<MaybeBatch>>,
     ) {
         // wait for completion, and propagate error
         // note we ignore errors on send (.ok) as that means the receiver has already shutdown.
@@ -681,6 +759,56 @@ impl RecordBatchStream for RepartitionStream {
     }
 }
 
+/// This struct converts a receiver to a stream.
+/// Receiver receives data on an SPSC channel.
+struct PerPartitionStream {
+    /// Schema wrapped by Arc
+    schema: SchemaRef,
+
+    /// channel containing the repartitioned batches
+    receiver: DistributionReceiver<MaybeBatch>,
+
+    /// Handle to ensure background tasks are killed when no longer needed.
+    #[allow(dead_code)]
+    drop_helper: Arc<AbortOnDropMany<()>>,
+
+    /// Memory reservation.
+    reservation: SharedMemoryReservation,
+}
+
+impl Stream for PerPartitionStream {
+    type Item = Result<RecordBatch>;
+
+    fn poll_next(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        match self.receiver.recv().poll_unpin(cx) {
+            Poll::Ready(Some(Some(v))) => {
+                if let Ok(batch) = &v {
+                    self.reservation
+                        .lock()
+                        .shrink(batch.get_array_memory_size());
+                }
+                Poll::Ready(Some(v))
+            }
+            Poll::Ready(Some(None)) => {
+                // Input partition has finished sending batches
+                Poll::Ready(None)
+            }
+            Poll::Ready(None) => Poll::Ready(None),
+            Poll::Pending => Poll::Pending,
+        }
+    }
+}
+
+impl RecordBatchStream for PerPartitionStream {
+    /// Get the schema
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
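`PerPartitionStream` is essentially a receiver-to-`Stream` adapter over one SPSC channel. The same pattern, with a tokio mpsc receiver standing in for `DistributionReceiver` and the memory-reservation and abort-handle bookkeeping omitted, looks like this (an illustrative sketch, not DataFusion code):

use std::pin::Pin;
use std::task::{Context, Poll};

use futures::Stream;
use tokio::sync::mpsc::Receiver;

/// Wraps one SPSC receiver so a single input partition can be consumed as a
/// `Stream` and handed to a merge operator.
struct ChannelStream<T> {
    receiver: Receiver<T>,
}

impl<T> Stream for ChannelStream<T> {
    type Item = T;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // `poll_recv` yields `Ready(None)` once the sender side is dropped,
        // which ends the stream, mirroring the "input partition finished" case.
        self.receiver.poll_recv(cx)
    }
}

#[tokio::main]
async fn main() {
    use futures::StreamExt;
    let (tx, rx) = tokio::sync::mpsc::channel(4);
    tx.send(1).await.unwrap();
    tx.send(2).await.unwrap();
    drop(tx);
    let collected: Vec<i32> = ChannelStream { receiver: rx }.collect().await;
    assert_eq!(collected, vec![1, 2]);
}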
@@ -705,6 +833,7 @@ mod tests {
     use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv};
     use futures::FutureExt;
     use std::collections::HashSet;
+    use tokio::task::JoinHandle;
 
     #[tokio::test]
     async fn one_to_many_round_robin() -> Result<()> {

datafusion/core/tests/fuzz_cases/mod.rs

Lines changed: 1 addition & 0 deletions
@@ -19,4 +19,5 @@ mod aggregate_fuzz;
 mod join_fuzz;
 mod merge_fuzz;
 mod order_spill_fuzz;
+mod sort_preserving_repartition_fuzz;
 mod window_fuzz;
