Skip to content

Commit 0d334cf

Browse files
richoxzhangli20
andauthored
Use tournament loser tree for k-way sort-merging (apache#4301)
Co-authored-by: zhangli20 <[email protected]>
1 parent 52e198e commit 0d334cf

File tree

2 files changed

+114
-60
lines changed

2 files changed

+114
-60
lines changed

datafusion/core/src/physical_plan/sorts/cursor.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,14 @@ impl PartialOrd for SortKeyCursor {
109109

110110
impl Ord for SortKeyCursor {
111111
fn cmp(&self, other: &Self) -> Ordering {
112-
self.current()
113-
.cmp(&other.current())
114-
.then_with(|| self.stream_idx.cmp(&other.stream_idx))
112+
match (self.is_finished(), other.is_finished()) {
113+
(true, true) => Ordering::Equal,
114+
(_, true) => Ordering::Less,
115+
(true, _) => Ordering::Greater,
116+
_ => self
117+
.current()
118+
.cmp(&other.current())
119+
.then_with(|| self.stream_idx.cmp(&other.stream_idx)),
120+
}
115121
}
116122
}

datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs

Lines changed: 105 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@
1818
//! Defines the sort preserving merge plan
1919
2020
use std::any::Any;
21-
use std::cmp::Reverse;
22-
use std::collections::{BinaryHeap, VecDeque};
21+
use std::collections::VecDeque;
2322
use std::pin::Pin;
2423
use std::sync::Arc;
2524
use std::task::{Context, Poll};
@@ -304,10 +303,6 @@ pub(crate) struct SortPreservingMergeStream {
304303
/// their rows have been yielded to the output
305304
batches: Vec<VecDeque<RecordBatch>>,
306305

307-
/// Maintain a flag for each stream denoting if the current cursor
308-
/// has finished and needs to poll from the stream
309-
cursor_finished: Vec<bool>,
310-
311306
/// The accumulated row indexes for the next record batch
312307
in_progress: Vec<RowIndex>,
313308

@@ -323,8 +318,17 @@ pub(crate) struct SortPreservingMergeStream {
323318
/// An id to uniquely identify the input stream batch
324319
next_batch_id: usize,
325320

326-
/// Heap that yields [`SortKeyCursor`] in increasing order
327-
heap: BinaryHeap<Reverse<SortKeyCursor>>,
321+
/// Vector that holds all [`SortKeyCursor`]s
322+
cursors: Vec<Option<SortKeyCursor>>,
323+
324+
/// The loser tree that always produces the minimum cursor
325+
///
326+
/// Node 0 stores the top winner, Nodes 1..num_streams store
327+
/// the loser nodes
328+
loser_tree: Vec<usize>,
329+
330+
/// Identify whether the loser tree is adjusted
331+
loser_tree_adjusted: bool,
328332

329333
/// target batch size
330334
batch_size: usize,
@@ -361,14 +365,15 @@ impl SortPreservingMergeStream {
361365
Ok(Self {
362366
schema,
363367
batches,
364-
cursor_finished: vec![true; stream_count],
365368
streams: MergingStreams::new(wrappers),
366369
column_expressions: expressions.iter().map(|x| x.expr.clone()).collect(),
367370
tracking_metrics,
368371
aborted: false,
369372
in_progress: vec![],
370373
next_batch_id: 0,
371-
heap: BinaryHeap::with_capacity(stream_count),
374+
cursors: (0..stream_count).into_iter().map(|_| None).collect(),
375+
loser_tree: Vec::with_capacity(stream_count),
376+
loser_tree_adjusted: false,
372377
batch_size,
373378
row_converter,
374379
})
@@ -382,7 +387,11 @@ impl SortPreservingMergeStream {
382387
cx: &mut Context<'_>,
383388
idx: usize,
384389
) -> Poll<ArrowResult<()>> {
385-
if !self.cursor_finished[idx] {
390+
if self.cursors[idx]
391+
.as_ref()
392+
.map(|cursor| !cursor.is_finished())
393+
.unwrap_or(false)
394+
{
386395
// Cursor is not finished - don't need a new RecordBatch yet
387396
return Poll::Ready(Ok(()));
388397
}
@@ -418,14 +427,12 @@ impl SortPreservingMergeStream {
418427
}
419428
};
420429

421-
let cursor = SortKeyCursor::new(
430+
self.cursors[idx] = Some(SortKeyCursor::new(
422431
idx,
423432
self.next_batch_id, // assign this batch an ID
424433
rows,
425-
);
434+
));
426435
self.next_batch_id += 1;
427-
self.heap.push(Reverse(cursor));
428-
self.cursor_finished[idx] = false;
429436
self.batches[idx].push_back(batch)
430437
} else {
431438
empty_batch = true;
@@ -551,17 +558,46 @@ impl SortPreservingMergeStream {
551558
if self.aborted {
552559
return Poll::Ready(None);
553560
}
561+
let num_streams = self.streams.num_streams();
562+
563+
// Init all cursors and the loser tree in the first poll
564+
if self.loser_tree.is_empty() {
565+
// Ensure all non-exhausted streams have a cursor from which
566+
// rows can be pulled
567+
for i in 0..num_streams {
568+
match futures::ready!(self.maybe_poll_stream(cx, i)) {
569+
Ok(_) => {}
570+
Err(e) => {
571+
self.aborted = true;
572+
return Poll::Ready(Some(Err(e)));
573+
}
574+
}
575+
}
554576

555-
// Ensure all non-exhausted streams have a cursor from which
556-
// rows can be pulled
557-
for i in 0..self.streams.num_streams() {
558-
match futures::ready!(self.maybe_poll_stream(cx, i)) {
559-
Ok(_) => {}
560-
Err(e) => {
561-
self.aborted = true;
562-
return Poll::Ready(Some(Err(e)));
577+
// Init loser tree
578+
self.loser_tree.resize(num_streams, usize::MAX);
579+
for i in 0..num_streams {
580+
let mut winner = i;
581+
let mut cmp_node = (num_streams + i) / 2;
582+
while cmp_node != 0 && self.loser_tree[cmp_node] != usize::MAX {
583+
let challenger = self.loser_tree[cmp_node];
584+
let challenger_win =
585+
match (&self.cursors[winner], &self.cursors[challenger]) {
586+
(None, _) => true,
587+
(_, None) => false,
588+
(Some(winner), Some(challenger)) => challenger < winner,
589+
};
590+
if challenger_win {
591+
self.loser_tree[cmp_node] = winner;
592+
winner = challenger;
593+
} else {
594+
self.loser_tree[cmp_node] = challenger;
595+
}
596+
cmp_node /= 2;
563597
}
598+
self.loser_tree[cmp_node] = winner;
564599
}
600+
self.loser_tree_adjusted = true;
565601
}
566602

567603
// NB timer records time taken on drop, so there are no
@@ -570,45 +606,57 @@ impl SortPreservingMergeStream {
570606
let _timer = elapsed_compute.timer();
571607

572608
loop {
573-
match self.heap.pop() {
574-
Some(Reverse(mut cursor)) => {
575-
let stream_idx = cursor.stream_idx();
576-
let batch_idx = self.batches[stream_idx].len() - 1;
577-
let row_idx = cursor.advance();
578-
579-
let mut cursor_finished = false;
580-
// insert the cursor back to heap if the record batch is not exhausted
581-
if !cursor.is_finished() {
582-
self.heap.push(Reverse(cursor));
583-
} else {
584-
cursor_finished = true;
585-
self.cursor_finished[stream_idx] = true;
609+
// Adjust the loser tree if necessary
610+
if !self.loser_tree_adjusted {
611+
let mut winner = self.loser_tree[0];
612+
match futures::ready!(self.maybe_poll_stream(cx, winner)) {
613+
Ok(_) => {}
614+
Err(e) => {
615+
self.aborted = true;
616+
return Poll::Ready(Some(Err(e)));
586617
}
618+
}
587619

588-
self.in_progress.push(RowIndex {
589-
stream_idx,
590-
batch_idx,
591-
row_idx,
592-
});
593-
594-
if self.in_progress.len() == self.batch_size {
595-
return Poll::Ready(Some(self.build_record_batch()));
620+
let mut cmp_node = (num_streams + winner) / 2;
621+
while cmp_node != 0 {
622+
let challenger = self.loser_tree[cmp_node];
623+
let challenger_win =
624+
match (&self.cursors[winner], &self.cursors[challenger]) {
625+
(None, _) => true,
626+
(_, None) => false,
627+
(Some(winner), Some(challenger)) => challenger < winner,
628+
};
629+
if challenger_win {
630+
self.loser_tree[cmp_node] = winner;
631+
winner = challenger;
596632
}
633+
cmp_node /= 2;
634+
}
635+
self.loser_tree[0] = winner;
636+
self.loser_tree_adjusted = true;
637+
}
597638

598-
// If removed the last row from the cursor, need to fetch a new record
599-
// batch if possible, before looping round again
600-
if cursor_finished {
601-
match futures::ready!(self.maybe_poll_stream(cx, stream_idx)) {
602-
Ok(_) => {}
603-
Err(e) => {
604-
self.aborted = true;
605-
return Poll::Ready(Some(Err(e)));
606-
}
607-
}
608-
}
639+
let min_cursor_idx = self.loser_tree[0];
640+
let next = self.cursors[min_cursor_idx]
641+
.as_mut()
642+
.filter(|cursor| !cursor.is_finished())
643+
.map(|cursor| (cursor.stream_idx(), cursor.advance()));
644+
645+
if let Some((stream_idx, row_idx)) = next {
646+
self.loser_tree_adjusted = false;
647+
let batch_idx = self.batches[stream_idx].len() - 1;
648+
self.in_progress.push(RowIndex {
649+
stream_idx,
650+
batch_idx,
651+
row_idx,
652+
});
653+
if self.in_progress.len() == self.batch_size {
654+
return Poll::Ready(Some(self.build_record_batch()));
609655
}
610-
None if self.in_progress.is_empty() => return Poll::Ready(None),
611-
None => return Poll::Ready(Some(self.build_record_batch())),
656+
} else if !self.in_progress.is_empty() {
657+
return Poll::Ready(Some(self.build_record_batch()));
658+
} else {
659+
return Poll::Ready(None);
612660
}
613661
}
614662
}

0 commit comments

Comments
 (0)