1717
1818//! Aggregate without grouping columns
1919
20+ use crate :: execution:: context:: TaskContext ;
21+ use crate :: execution:: memory_manager:: proxy:: MemoryConsumerProxy ;
22+ use crate :: execution:: MemoryConsumerId ;
2023use crate :: physical_plan:: aggregates:: {
2124 aggregate_expressions, create_accumulators, finalize_aggregation, AccumulatorItem ,
2225 AggregateMode ,
@@ -28,22 +31,31 @@ use arrow::error::{ArrowError, Result as ArrowResult};
2831use arrow:: record_batch:: RecordBatch ;
2932use datafusion_common:: Result ;
3033use datafusion_physical_expr:: { AggregateExpr , PhysicalExpr } ;
34+ use futures:: stream:: BoxStream ;
3135use std:: sync:: Arc ;
3236use std:: task:: { Context , Poll } ;
3337
34- use futures:: {
35- ready,
36- stream:: { Stream , StreamExt } ,
37- } ;
38+ use futures:: stream:: { Stream , StreamExt } ;
3839
3940/// stream struct for aggregation without grouping columns
4041pub ( crate ) struct AggregateStream {
42+ stream : BoxStream < ' static , ArrowResult < RecordBatch > > ,
43+ schema : SchemaRef ,
44+ }
45+
46+ /// Actual implementation of [`AggregateStream`].
47+ ///
48+ /// This is wrapped into yet another struct because we need to interact with the async memory management subsystem
49+ /// during poll. To have as little code "weirdness" as possible, we chose to just use [`BoxStream`] together with
50+ /// [`futures::stream::unfold`]. The latter requires a state object, which is [`AggregateStreamInner`].
51+ struct AggregateStreamInner {
4152 schema : SchemaRef ,
4253 mode : AggregateMode ,
4354 input : SendableRecordBatchStream ,
4455 baseline_metrics : BaselineMetrics ,
4556 aggregate_expressions : Vec < Vec < Arc < dyn PhysicalExpr > > > ,
4657 accumulators : Vec < AccumulatorItem > ,
58+ memory_consumer : MemoryConsumerProxy ,
4759 finished : bool ,
4860}
4961
@@ -55,19 +67,87 @@ impl AggregateStream {
5567 aggr_expr : Vec < Arc < dyn AggregateExpr > > ,
5668 input : SendableRecordBatchStream ,
5769 baseline_metrics : BaselineMetrics ,
70+ context : Arc < TaskContext > ,
71+ partition : usize ,
5872 ) -> datafusion_common:: Result < Self > {
5973 let aggregate_expressions = aggregate_expressions ( & aggr_expr, & mode, 0 ) ?;
6074 let accumulators = create_accumulators ( & aggr_expr) ?;
61-
62- Ok ( Self {
63- schema,
75+ let memory_consumer = MemoryConsumerProxy :: new (
76+ "AggregationState" ,
77+ MemoryConsumerId :: new ( partition) ,
78+ Arc :: clone ( & context. runtime_env ( ) . memory_manager ) ,
79+ ) ;
80+
81+ let inner = AggregateStreamInner {
82+ schema : Arc :: clone ( & schema) ,
6483 mode,
6584 input,
6685 baseline_metrics,
6786 aggregate_expressions,
6887 accumulators,
88+ memory_consumer,
6989 finished : false ,
70- } )
90+ } ;
91+ let stream = futures:: stream:: unfold ( inner, |mut this| async move {
92+ if this. finished {
93+ return None ;
94+ }
95+
96+ let elapsed_compute = this. baseline_metrics . elapsed_compute ( ) ;
97+
98+ loop {
99+ let result = match this. input . next ( ) . await {
100+ Some ( Ok ( batch) ) => {
101+ let timer = elapsed_compute. timer ( ) ;
102+ let result = aggregate_batch (
103+ & this. mode ,
104+ & batch,
105+ & mut this. accumulators ,
106+ & this. aggregate_expressions ,
107+ ) ;
108+
109+ timer. done ( ) ;
110+
111+ // allocate memory
112+ // This happens AFTER we actually used the memory, but simplifies the whole accounting and we are OK with
113+ // overshooting a bit. Also this means we either store the whole record batch or not.
114+ let result = match result {
115+ Ok ( allocated) => this. memory_consumer . alloc ( allocated) . await ,
116+ Err ( e) => Err ( e) ,
117+ } ;
118+
119+ match result {
120+ Ok ( _) => continue ,
121+ Err ( e) => Err ( ArrowError :: ExternalError ( Box :: new ( e) ) ) ,
122+ }
123+ }
124+ Some ( Err ( e) ) => Err ( e) ,
125+ None => {
126+ this. finished = true ;
127+ let timer = this. baseline_metrics . elapsed_compute ( ) . timer ( ) ;
128+ let result = finalize_aggregation ( & this. accumulators , & this. mode )
129+ . map_err ( |e| ArrowError :: ExternalError ( Box :: new ( e) ) )
130+ . and_then ( |columns| {
131+ RecordBatch :: try_new ( this. schema . clone ( ) , columns)
132+ } )
133+ . record_output ( & this. baseline_metrics ) ;
134+
135+ timer. done ( ) ;
136+
137+ result
138+ }
139+ } ;
140+
141+ this. finished = true ;
142+ return Some ( ( result, this) ) ;
143+ }
144+ } ) ;
145+
146+ // seems like some consumers call this stream even after it returned `None`, so let's fuse the stream.
147+ let stream = stream. fuse ( ) ;
148+ let stream = Box :: pin ( stream) ;
149+
150+ Ok ( Self { schema, stream } )
71151 }
72152}
73153
@@ -79,49 +159,7 @@ impl Stream for AggregateStream {
79159 cx : & mut Context < ' _ > ,
80160 ) -> Poll < Option < Self :: Item > > {
81161 let this = & mut * self ;
82- if this. finished {
83- return Poll :: Ready ( None ) ;
84- }
85-
86- let elapsed_compute = this. baseline_metrics . elapsed_compute ( ) ;
87-
88- loop {
89- let result = match ready ! ( this. input. poll_next_unpin( cx) ) {
90- Some ( Ok ( batch) ) => {
91- let timer = elapsed_compute. timer ( ) ;
92- let result = aggregate_batch (
93- & this. mode ,
94- & batch,
95- & mut this. accumulators ,
96- & this. aggregate_expressions ,
97- ) ;
98-
99- timer. done ( ) ;
100-
101- match result {
102- Ok ( _) => continue ,
103- Err ( e) => Err ( ArrowError :: ExternalError ( Box :: new ( e) ) ) ,
104- }
105- }
106- Some ( Err ( e) ) => Err ( e) ,
107- None => {
108- this. finished = true ;
109- let timer = this. baseline_metrics . elapsed_compute ( ) . timer ( ) ;
110- let result = finalize_aggregation ( & this. accumulators , & this. mode )
111- . map_err ( |e| ArrowError :: ExternalError ( Box :: new ( e) ) )
112- . and_then ( |columns| {
113- RecordBatch :: try_new ( this. schema . clone ( ) , columns)
114- } )
115- . record_output ( & this. baseline_metrics ) ;
116-
117- timer. done ( ) ;
118- result
119- }
120- } ;
121-
122- this. finished = true ;
123- return Poll :: Ready ( Some ( result) ) ;
124- }
162+ this. stream . poll_next_unpin ( cx)
125163 }
126164}
127165
@@ -131,13 +169,19 @@ impl RecordBatchStream for AggregateStream {
131169 }
132170}
133171
172+ /// Perform aggregation without grouping columns for the given [`RecordBatch`].
173+ ///
174+ /// If successful, this returns the additional number of bytes that were allocated during this process.
175+ ///
134176/// TODO: Make this a member function
135177fn aggregate_batch (
136178 mode : & AggregateMode ,
137179 batch : & RecordBatch ,
138180 accumulators : & mut [ AccumulatorItem ] ,
139181 expressions : & [ Vec < Arc < dyn PhysicalExpr > > ] ,
140- ) -> Result < ( ) > {
182+ ) -> Result < usize > {
183+ let mut allocated = 0usize ;
184+
141185 // 1.1 iterate accumulators and respective expressions together
142186 // 1.2 evaluate expressions
143187 // 1.3 update / merge accumulators with the expressions' values
@@ -155,11 +199,17 @@ fn aggregate_batch(
155199 . collect :: < Result < Vec < _ > > > ( ) ?;
156200
157201 // 1.3
158- match mode {
202+ let size_pre = accum. size ( ) ;
203+ let res = match mode {
159204 AggregateMode :: Partial => accum. update_batch ( values) ,
160205 AggregateMode :: Final | AggregateMode :: FinalPartitioned => {
161206 accum. merge_batch ( values)
162207 }
163- }
164- } )
208+ } ;
209+ let size_post = accum. size ( ) ;
210+ allocated += size_post. saturating_sub ( size_pre) ;
211+ res
212+ } ) ?;
213+
214+ Ok ( allocated)
165215}
0 commit comments