17
17
18
18
//! Hash aggregation
19
19
20
+ use std:: collections:: VecDeque ;
20
21
use std:: sync:: Arc ;
21
22
use std:: task:: { Context , Poll } ;
22
23
use std:: vec;
@@ -61,7 +62,7 @@ pub(crate) enum ExecutionState {
61
62
ReadingInput ,
62
63
/// When producing output, the batches still to be emitted are queued
63
64
/// here and are popped from the front one at a time
64
- ProducingOutput ( RecordBatch ) ,
65
+ ProducingOutput ( VecDeque < RecordBatch > ) ,
65
66
/// Produce intermediate aggregate state for each input row without
66
67
/// aggregation.
67
68
///
@@ -553,7 +554,7 @@ impl Stream for GroupedHashAggregateStream {
553
554
let elapsed_compute = self . baseline_metrics . elapsed_compute ( ) . clone ( ) ;
554
555
555
556
loop {
556
- match & self . exec_state {
557
+ match & mut self . exec_state {
557
558
ExecutionState :: ReadingInput => ' reading_input: {
558
559
match ready ! ( self . input. poll_next_unpin( cx) ) {
559
560
// new batch to aggregate
@@ -583,8 +584,9 @@ impl Stream for GroupedHashAggregateStream {
583
584
}
584
585
585
586
if let Some ( to_emit) = self . group_ordering . emit_to ( ) {
586
- let batch = extract_ok ! ( self . emit( to_emit, false ) ) ;
587
- self . exec_state = ExecutionState :: ProducingOutput ( batch) ;
587
+ let batches = extract_ok ! ( self . emit( to_emit, false ) ) ;
588
+ self . exec_state =
589
+ ExecutionState :: ProducingOutput ( batches) ;
588
590
timer. done ( ) ;
589
591
// make sure the exec_state just set is not overwritten below
590
592
break ' reading_input;
@@ -627,29 +629,20 @@ impl Stream for GroupedHashAggregateStream {
627
629
}
628
630
}
629
631
630
- ExecutionState :: ProducingOutput ( batch) => {
631
- // slice off a part of the batch, if needed
632
- let output_batch;
633
- let size = self . batch_size ;
634
- ( self . exec_state , output_batch) = if batch. num_rows ( ) <= size {
635
- (
636
- if self . input_done {
637
- ExecutionState :: Done
638
- } else if self . should_skip_aggregation ( ) {
639
- ExecutionState :: SkippingAggregation
640
- } else {
641
- ExecutionState :: ReadingInput
642
- } ,
643
- batch. clone ( ) ,
644
- )
645
- } else {
646
- // output first batch_size rows
647
- let size = self . batch_size ;
648
- let num_remaining = batch. num_rows ( ) - size;
649
- let remaining = batch. slice ( size, num_remaining) ;
650
- let output = batch. slice ( 0 , size) ;
651
- ( ExecutionState :: ProducingOutput ( remaining) , output)
652
- } ;
632
+ ExecutionState :: ProducingOutput ( batches) => {
633
+ assert ! ( !batches. is_empty( ) ) ;
634
+ let output_batch = batches. pop_front ( ) . expect ( "RecordBatch" ) ;
635
+
636
+ if batches. is_empty ( ) {
637
+ self . exec_state = if self . input_done {
638
+ ExecutionState :: Done
639
+ } else if self . should_skip_aggregation ( ) {
640
+ ExecutionState :: SkippingAggregation
641
+ } else {
642
+ ExecutionState :: ReadingInput
643
+ } ;
644
+ }
645
+
653
646
return Poll :: Ready ( Some ( Ok (
654
647
output_batch. record_output ( & self . baseline_metrics )
655
648
) ) ) ;
@@ -777,14 +770,55 @@ impl GroupedHashAggregateStream {
777
770
778
771
/// Create the output RecordBatches (one or more) with the group keys and
779
772
/// accumulator states/values specified in emit_to
780
- fn emit ( & mut self , emit_to : EmitTo , spilling : bool ) -> Result < RecordBatch > {
773
+ fn emit ( & mut self , emit_to : EmitTo , spilling : bool ) -> Result < VecDeque < RecordBatch > > {
781
774
let schema = if spilling {
782
775
Arc :: clone ( & self . spill_state . spill_schema )
783
776
} else {
784
777
self . schema ( )
785
778
} ;
786
779
if self . group_values . is_empty ( ) {
787
- return Ok ( RecordBatch :: new_empty ( schema) ) ;
780
+ return Ok ( VecDeque :: from ( [ RecordBatch :: new_empty ( schema) ] ) ) ;
781
+ }
782
+
783
+ if matches ! ( emit_to, EmitTo :: All ) && !spilling {
784
+ let outputs = self
785
+ . group_values
786
+ . emit_all_with_batch_size ( self . batch_size ) ?;
787
+
788
+ let mut batches = VecDeque :: with_capacity ( outputs. len ( ) ) ;
789
+ for mut output in outputs {
790
+ let num_rows = output[ 0 ] . len ( ) ;
791
+ // A full chunk drains exactly batch_size groups; a shorter chunk is presumably the last one, so it drains the rest (TODO confirm)
792
+ let batch_emit_to = if num_rows == self . batch_size {
793
+ EmitTo :: First ( self . batch_size )
794
+ } else {
795
+ EmitTo :: All
796
+ } ;
797
+
798
+ for acc in self . accumulators . iter_mut ( ) {
799
+ match self . mode {
800
+ AggregateMode :: Partial => {
801
+ output. extend ( acc. state ( batch_emit_to) ?)
802
+ }
803
+ _ if spilling => {
804
+ // If spilling, output partial state because the spilled data will be
805
+ // merged and re-evaluated later.
806
+ output. extend ( acc. state ( batch_emit_to) ?)
807
+ }
808
+ AggregateMode :: Final
809
+ | AggregateMode :: FinalPartitioned
810
+ | AggregateMode :: Single
811
+ | AggregateMode :: SinglePartitioned => {
812
+ output. push ( acc. evaluate ( batch_emit_to) ?)
813
+ }
814
+ }
815
+ }
816
+ let batch = RecordBatch :: try_new ( Arc :: clone ( & schema) , output) ?;
817
+ batches. push_back ( batch) ;
818
+ }
819
+
820
+ let _ = self . update_memory_reservation ( ) ;
821
+ return Ok ( batches) ;
788
822
}
789
823
790
824
let mut output = self . group_values . emit ( emit_to) ?;
@@ -812,7 +846,7 @@ impl GroupedHashAggregateStream {
812
846
// over the target memory size after emission, we can emit again rather than returning Err.
813
847
let _ = self . update_memory_reservation ( ) ;
814
848
let batch = RecordBatch :: try_new ( schema, output) ?;
815
- Ok ( batch)
849
+ Ok ( VecDeque :: from ( [ batch] ) )
816
850
}
817
851
818
852
/// Optimistically, [`Self::group_aggregate_batch`] allows exceeding the memory target slightly
@@ -838,7 +872,9 @@ impl GroupedHashAggregateStream {
838
872
839
873
/// Emit all rows, sort them, and store them on disk.
840
874
fn spill ( & mut self ) -> Result < ( ) > {
841
- let emit = self . emit ( EmitTo :: All , true ) ?;
875
+ let mut batches = self . emit ( EmitTo :: All , true ) ?;
876
+ assert_eq ! ( batches. len( ) , 1 ) ;
877
+ let emit = batches. pop_front ( ) . expect ( "RecordBatch" ) ;
842
878
let sorted = sort_batch ( & emit, & self . spill_state . spill_expr , None ) ?;
843
879
let spillfile = self . runtime . disk_manager . create_tmp_file ( "HashAggSpill" ) ?;
844
880
let mut writer = IPCWriter :: new ( spillfile. path ( ) , & emit. schema ( ) ) ?;
@@ -881,8 +917,8 @@ impl GroupedHashAggregateStream {
881
917
&& self . update_memory_reservation ( ) . is_err ( )
882
918
{
883
919
let n = self . group_values . len ( ) / self . batch_size * self . batch_size ;
884
- let batch = self . emit ( EmitTo :: First ( n) , false ) ?;
885
- self . exec_state = ExecutionState :: ProducingOutput ( batch ) ;
920
+ let batches = self . emit ( EmitTo :: First ( n) , false ) ?;
921
+ self . exec_state = ExecutionState :: ProducingOutput ( batches ) ;
886
922
}
887
923
Ok ( ( ) )
888
924
}
@@ -892,18 +928,22 @@ impl GroupedHashAggregateStream {
892
928
/// Conduct a streaming merge sort between the batch and spilled data. Since the stream is fully
893
929
/// sorted, set `self.group_ordering` to Full, then later we can read with [`EmitTo::First`].
894
930
fn update_merged_stream ( & mut self ) -> Result < ( ) > {
895
- let batch = self . emit ( EmitTo :: All , true ) ?;
931
+ let batches = self . emit ( EmitTo :: All , true ) ?;
932
+ assert ! ( !batches. is_empty( ) ) ;
933
+ let schema = batches[ 0 ] . schema ( ) ;
896
934
// clear up memory for streaming_merge
897
935
self . clear_all ( ) ;
898
936
self . update_memory_reservation ( ) ?;
899
937
let mut streams: Vec < SendableRecordBatchStream > = vec ! [ ] ;
900
938
let expr = self . spill_state . spill_expr . clone ( ) ;
901
- let schema = batch. schema ( ) ;
939
+ // TODO: avoid sorting and collecting eagerly; each batch could be sorted lazily as the stream is polled
940
+ let sorted = batches
941
+ . into_iter ( )
942
+ . map ( |batch| sort_batch ( & batch, & expr, None ) )
943
+ . collect :: < Vec < _ > > ( ) ;
902
944
streams. push ( Box :: pin ( RecordBatchStreamAdapter :: new (
903
945
Arc :: clone ( & schema) ,
904
- futures:: stream:: once ( futures:: future:: lazy ( move |_| {
905
- sort_batch ( & batch, & expr, None )
906
- } ) ) ,
946
+ futures:: stream:: iter ( sorted) ,
907
947
) ) ) ;
908
948
for spill in self . spill_state . spills . drain ( ..) {
909
949
let stream = read_spill_as_stream ( spill, Arc :: clone ( & schema) , 2 ) ?;
@@ -940,8 +980,8 @@ impl GroupedHashAggregateStream {
940
980
let elapsed_compute = self . baseline_metrics . elapsed_compute ( ) . clone ( ) ;
941
981
let timer = elapsed_compute. timer ( ) ;
942
982
self . exec_state = if self . spill_state . spills . is_empty ( ) {
943
- let batch = self . emit ( EmitTo :: All , false ) ?;
944
- ExecutionState :: ProducingOutput ( batch )
983
+ let batches = self . emit ( EmitTo :: All , false ) ?;
984
+ ExecutionState :: ProducingOutput ( batches )
945
985
} else {
946
986
// If spill files exist, stream-merge them.
947
987
self . update_merged_stream ( ) ?;
0 commit comments