1818//! Defines the sort preserving merge plan
1919
2020use std:: any:: Any ;
21- use std:: cmp:: Reverse ;
22- use std:: collections:: { BinaryHeap , VecDeque } ;
21+ use std:: collections:: VecDeque ;
2322use std:: pin:: Pin ;
2423use std:: sync:: Arc ;
2524use std:: task:: { Context , Poll } ;
@@ -304,10 +303,6 @@ pub(crate) struct SortPreservingMergeStream {
304303 /// their rows have been yielded to the output
305304 batches : Vec < VecDeque < RecordBatch > > ,
306305
307- /// Maintain a flag for each stream denoting if the current cursor
308- /// has finished and needs to poll from the stream
309- cursor_finished : Vec < bool > ,
310-
311306 /// The accumulated row indexes for the next record batch
312307 in_progress : Vec < RowIndex > ,
313308
@@ -323,8 +318,17 @@ pub(crate) struct SortPreservingMergeStream {
323318 /// An id to uniquely identify the input stream batch
324319 next_batch_id : usize ,
325320
326- /// Heap that yields [`SortKeyCursor`] in increasing order
327- heap : BinaryHeap < Reverse < SortKeyCursor > > ,
321+ /// Vector that holds all [`SortKeyCursor`]s
322+ cursors : Vec < Option < SortKeyCursor > > ,
323+
324+ /// The loser tree that always produces the minimum cursor
325+ ///
326+ /// Node 0 stores the index (into `cursors`) of the overall winner; nodes
327+ /// 1..num_streams store the indices of the cursors that lost at each node
328+ loser_tree : Vec < usize > ,
329+
330+ /// Whether the loser tree is up to date; cleared after the winning cursor is advanced
331+ loser_tree_adjusted : bool ,
328332
329333 /// target batch size
330334 batch_size : usize ,
@@ -361,14 +365,15 @@ impl SortPreservingMergeStream {
361365 Ok ( Self {
362366 schema,
363367 batches,
364- cursor_finished : vec ! [ true ; stream_count] ,
365368 streams : MergingStreams :: new ( wrappers) ,
366369 column_expressions : expressions. iter ( ) . map ( |x| x. expr . clone ( ) ) . collect ( ) ,
367370 tracking_metrics,
368371 aborted : false ,
369372 in_progress : vec ! [ ] ,
370373 next_batch_id : 0 ,
371- heap : BinaryHeap :: with_capacity ( stream_count) ,
374+ cursors : ( 0 ..stream_count) . into_iter ( ) . map ( |_| None ) . collect ( ) ,
375+ loser_tree : Vec :: with_capacity ( stream_count) ,
376+ loser_tree_adjusted : false ,
372377 batch_size,
373378 row_converter,
374379 } )
@@ -382,7 +387,11 @@ impl SortPreservingMergeStream {
382387 cx : & mut Context < ' _ > ,
383388 idx : usize ,
384389 ) -> Poll < ArrowResult < ( ) > > {
385- if !self . cursor_finished [ idx] {
390+ if self . cursors [ idx]
391+ . as_ref ( )
392+ . map ( |cursor| !cursor. is_finished ( ) )
393+ . unwrap_or ( false )
394+ {
386395 // Cursor is not finished - don't need a new RecordBatch yet
387396 return Poll :: Ready ( Ok ( ( ) ) ) ;
388397 }
@@ -418,14 +427,12 @@ impl SortPreservingMergeStream {
418427 }
419428 } ;
420429
421- let cursor = SortKeyCursor :: new (
430+ self . cursors [ idx ] = Some ( SortKeyCursor :: new (
422431 idx,
423432 self . next_batch_id , // assign this batch an ID
424433 rows,
425- ) ;
434+ ) ) ;
426435 self . next_batch_id += 1 ;
427- self . heap . push ( Reverse ( cursor) ) ;
428- self . cursor_finished [ idx] = false ;
429436 self . batches [ idx] . push_back ( batch)
430437 } else {
431438 empty_batch = true ;
@@ -551,17 +558,46 @@ impl SortPreservingMergeStream {
551558 if self . aborted {
552559 return Poll :: Ready ( None ) ;
553560 }
561+ let num_streams = self . streams . num_streams ( ) ;
562+
563+ // Initialise all cursors and the loser tree; re-attempted on each poll until the tree is built
564+ if self . loser_tree . is_empty ( ) {
565+ // Ensure all non-exhausted streams have a cursor from which
566+ // rows can be pulled
567+ for i in 0 ..num_streams {
568+ match futures:: ready!( self . maybe_poll_stream( cx, i) ) {
569+ Ok ( _) => { }
570+ Err ( e) => {
571+ self . aborted = true ;
572+ return Poll :: Ready ( Some ( Err ( e) ) ) ;
573+ }
574+ }
575+ }
554576
555- // Ensure all non-exhausted streams have a cursor from which
556- // rows can be pulled
557- for i in 0 ..self . streams . num_streams ( ) {
558- match futures:: ready!( self . maybe_poll_stream( cx, i) ) {
559- Ok ( _) => { }
560- Err ( e) => {
561- self . aborted = true ;
562- return Poll :: Ready ( Some ( Err ( e) ) ) ;
577+ // Init loser tree
578+ self . loser_tree . resize ( num_streams, usize:: MAX ) ;
579+ for i in 0 ..num_streams {
580+ let mut winner = i;
581+ let mut cmp_node = ( num_streams + i) / 2 ;
582+ while cmp_node != 0 && self . loser_tree [ cmp_node] != usize:: MAX {
583+ let challenger = self . loser_tree [ cmp_node] ;
584+ let challenger_win =
585+ match ( & self . cursors [ winner] , & self . cursors [ challenger] ) {
586+ ( None , _) => true ,
587+ ( _, None ) => false ,
588+ ( Some ( winner) , Some ( challenger) ) => challenger < winner,
589+ } ;
590+ if challenger_win {
591+ self . loser_tree [ cmp_node] = winner;
592+ winner = challenger;
593+ } else {
594+ self . loser_tree [ cmp_node] = challenger;
595+ }
596+ cmp_node /= 2 ;
563597 }
598+ self . loser_tree [ cmp_node] = winner;
564599 }
600+ self . loser_tree_adjusted = true ;
565601 }
566602
567603 // NB timer records time taken on drop, so there are no
@@ -570,45 +606,57 @@ impl SortPreservingMergeStream {
570606 let _timer = elapsed_compute. timer ( ) ;
571607
572608 loop {
573- match self . heap . pop ( ) {
574- Some ( Reverse ( mut cursor) ) => {
575- let stream_idx = cursor. stream_idx ( ) ;
576- let batch_idx = self . batches [ stream_idx] . len ( ) - 1 ;
577- let row_idx = cursor. advance ( ) ;
578-
579- let mut cursor_finished = false ;
580- // insert the cursor back to heap if the record batch is not exhausted
581- if !cursor. is_finished ( ) {
582- self . heap . push ( Reverse ( cursor) ) ;
583- } else {
584- cursor_finished = true ;
585- self . cursor_finished [ stream_idx] = true ;
609+ // Adjust the loser tree if necessary
610+ if !self . loser_tree_adjusted {
611+ let mut winner = self . loser_tree [ 0 ] ;
612+ match futures:: ready!( self . maybe_poll_stream( cx, winner) ) {
613+ Ok ( _) => { }
614+ Err ( e) => {
615+ self . aborted = true ;
616+ return Poll :: Ready ( Some ( Err ( e) ) ) ;
586617 }
618+ }
587619
588- self . in_progress . push ( RowIndex {
589- stream_idx,
590- batch_idx,
591- row_idx,
592- } ) ;
593-
594- if self . in_progress . len ( ) == self . batch_size {
595- return Poll :: Ready ( Some ( self . build_record_batch ( ) ) ) ;
620+ let mut cmp_node = ( num_streams + winner) / 2 ;
621+ while cmp_node != 0 {
622+ let challenger = self . loser_tree [ cmp_node] ;
623+ let challenger_win =
624+ match ( & self . cursors [ winner] , & self . cursors [ challenger] ) {
625+ ( None , _) => true ,
626+ ( _, None ) => false ,
627+ ( Some ( winner) , Some ( challenger) ) => challenger < winner,
628+ } ;
629+ if challenger_win {
630+ self . loser_tree [ cmp_node] = winner;
631+ winner = challenger;
596632 }
633+ cmp_node /= 2 ;
634+ }
635+ self . loser_tree [ 0 ] = winner;
636+ self . loser_tree_adjusted = true ;
637+ }
597638
598- // If removed the last row from the cursor, need to fetch a new record
599- // batch if possible, before looping round again
600- if cursor_finished {
601- match futures:: ready!( self . maybe_poll_stream( cx, stream_idx) ) {
602- Ok ( _) => { }
603- Err ( e) => {
604- self . aborted = true ;
605- return Poll :: Ready ( Some ( Err ( e) ) ) ;
606- }
607- }
608- }
639+ let min_cursor_idx = self . loser_tree [ 0 ] ;
640+ let next = self . cursors [ min_cursor_idx]
641+ . as_mut ( )
642+ . filter ( |cursor| !cursor. is_finished ( ) )
643+ . map ( |cursor| ( cursor. stream_idx ( ) , cursor. advance ( ) ) ) ;
644+
645+ if let Some ( ( stream_idx, row_idx) ) = next {
646+ self . loser_tree_adjusted = false ;
647+ let batch_idx = self . batches [ stream_idx] . len ( ) - 1 ;
648+ self . in_progress . push ( RowIndex {
649+ stream_idx,
650+ batch_idx,
651+ row_idx,
652+ } ) ;
653+ if self . in_progress . len ( ) == self . batch_size {
654+ return Poll :: Ready ( Some ( self . build_record_batch ( ) ) ) ;
609655 }
610- None if self . in_progress . is_empty ( ) => return Poll :: Ready ( None ) ,
611- None => return Poll :: Ready ( Some ( self . build_record_batch ( ) ) ) ,
656+ } else if !self . in_progress . is_empty ( ) {
657+ return Poll :: Ready ( Some ( self . build_record_batch ( ) ) ) ;
658+ } else {
659+ return Poll :: Ready ( None ) ;
612660 }
613661 }
614662 }
0 commit comments