Skip to content

Commit dd3f72a

Browse files
authored
feat: ResourceExhausted for memory limit in AggregateStream (apache#4405)
Closes apache#3940.
1 parent 0d334cf commit dd3f72a

File tree

2 files changed

+126
-59
lines changed

2 files changed

+126
-59
lines changed

datafusion/core/src/physical_plan/aggregates/mod.rs

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,8 @@ impl AggregateExec {
295295
self.aggr_expr.clone(),
296296
input,
297297
baseline_metrics,
298+
context,
299+
partition,
298300
)?))
299301
} else if self.row_aggregate_supported() {
300302
Ok(StreamType::GroupedHashAggregateStreamV2(
@@ -737,7 +739,7 @@ mod tests {
737739
use arrow::error::{ArrowError, Result as ArrowResult};
738740
use arrow::record_batch::RecordBatch;
739741
use datafusion_common::{DataFusionError, Result, ScalarValue};
740-
use datafusion_physical_expr::expressions::{lit, ApproxDistinct, Count};
742+
use datafusion_physical_expr::expressions::{lit, ApproxDistinct, Count, Median};
741743
use datafusion_physical_expr::{AggregateExpr, PhysicalExpr, PhysicalSortExpr};
742744
use futures::{FutureExt, Stream};
743745
use std::any::Any;
@@ -1131,12 +1133,20 @@ mod tests {
11311133
);
11321134
let task_ctx = session_ctx.task_ctx();
11331135

1134-
let groups = PhysicalGroupBy {
1136+
let groups_none = PhysicalGroupBy::default();
1137+
let groups_some = PhysicalGroupBy {
11351138
expr: vec![(col("a", &input_schema)?, "a".to_string())],
11361139
null_expr: vec![],
11371140
groups: vec![vec![false]],
11381141
};
11391142

1143+
// something that allocates within the aggregator
1144+
let aggregates_v0: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Median::new(
1145+
col("a", &input_schema)?,
1146+
"MEDIAN(a)".to_string(),
1147+
DataType::UInt32,
1148+
))];
1149+
11401150
// use slow-path in `hash.rs`
11411151
let aggregates_v1: Vec<Arc<dyn AggregateExpr>> =
11421152
vec![Arc::new(ApproxDistinct::new(
@@ -1152,10 +1162,14 @@ mod tests {
11521162
DataType::Float64,
11531163
))];
11541164

1155-
for (version, aggregates) in [(1, aggregates_v1), (2, aggregates_v2)] {
1165+
for (version, groups, aggregates) in [
1166+
(0, groups_none, aggregates_v0),
1167+
(1, groups_some.clone(), aggregates_v1),
1168+
(2, groups_some, aggregates_v2),
1169+
] {
11561170
let partial_aggregate = Arc::new(AggregateExec::try_new(
11571171
AggregateMode::Partial,
1158-
groups.clone(),
1172+
groups,
11591173
aggregates,
11601174
input.clone(),
11611175
input_schema.clone(),
@@ -1165,6 +1179,9 @@ mod tests {
11651179

11661180
// ensure that we really got the version we wanted
11671181
match version {
1182+
0 => {
1183+
assert!(matches!(stream, StreamType::AggregateStream(_)));
1184+
}
11681185
1 => {
11691186
assert!(matches!(stream, StreamType::GroupedHashAggregateStream(_)));
11701187
}

datafusion/core/src/physical_plan/aggregates/no_grouping.rs

Lines changed: 105 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717

1818
//! Aggregate without grouping columns
1919
20+
use crate::execution::context::TaskContext;
21+
use crate::execution::memory_manager::proxy::MemoryConsumerProxy;
22+
use crate::execution::MemoryConsumerId;
2023
use crate::physical_plan::aggregates::{
2124
aggregate_expressions, create_accumulators, finalize_aggregation, AccumulatorItem,
2225
AggregateMode,
@@ -28,22 +31,31 @@ use arrow::error::{ArrowError, Result as ArrowResult};
2831
use arrow::record_batch::RecordBatch;
2932
use datafusion_common::Result;
3033
use datafusion_physical_expr::{AggregateExpr, PhysicalExpr};
34+
use futures::stream::BoxStream;
3135
use std::sync::Arc;
3236
use std::task::{Context, Poll};
3337

34-
use futures::{
35-
ready,
36-
stream::{Stream, StreamExt},
37-
};
38+
use futures::stream::{Stream, StreamExt};
3839

3940
/// stream struct for aggregation without grouping columns
4041
pub(crate) struct AggregateStream {
42+
stream: BoxStream<'static, ArrowResult<RecordBatch>>,
43+
schema: SchemaRef,
44+
}
45+
46+
/// Actual implementation of [`AggregateStream`].
47+
///
48+
/// This is wrapped into yet another struct because we need to interact with the async memory management subsystem
49+
/// during poll. To have as little code "weirdness" as possible, we chose to just use [`BoxStream`] together with
50+
/// [`futures::stream::unfold`]. The latter requires a state object, which is [`GroupedHashAggregateStreamV2Inner`].
51+
struct AggregateStreamInner {
4152
schema: SchemaRef,
4253
mode: AggregateMode,
4354
input: SendableRecordBatchStream,
4455
baseline_metrics: BaselineMetrics,
4556
aggregate_expressions: Vec<Vec<Arc<dyn PhysicalExpr>>>,
4657
accumulators: Vec<AccumulatorItem>,
58+
memory_consumer: MemoryConsumerProxy,
4759
finished: bool,
4860
}
4961

@@ -55,19 +67,87 @@ impl AggregateStream {
5567
aggr_expr: Vec<Arc<dyn AggregateExpr>>,
5668
input: SendableRecordBatchStream,
5769
baseline_metrics: BaselineMetrics,
70+
context: Arc<TaskContext>,
71+
partition: usize,
5872
) -> datafusion_common::Result<Self> {
5973
let aggregate_expressions = aggregate_expressions(&aggr_expr, &mode, 0)?;
6074
let accumulators = create_accumulators(&aggr_expr)?;
61-
62-
Ok(Self {
63-
schema,
75+
let memory_consumer = MemoryConsumerProxy::new(
76+
"AggregationState",
77+
MemoryConsumerId::new(partition),
78+
Arc::clone(&context.runtime_env().memory_manager),
79+
);
80+
81+
let inner = AggregateStreamInner {
82+
schema: Arc::clone(&schema),
6483
mode,
6584
input,
6685
baseline_metrics,
6786
aggregate_expressions,
6887
accumulators,
88+
memory_consumer,
6989
finished: false,
70-
})
90+
};
91+
let stream = futures::stream::unfold(inner, |mut this| async move {
92+
if this.finished {
93+
return None;
94+
}
95+
96+
let elapsed_compute = this.baseline_metrics.elapsed_compute();
97+
98+
loop {
99+
let result = match this.input.next().await {
100+
Some(Ok(batch)) => {
101+
let timer = elapsed_compute.timer();
102+
let result = aggregate_batch(
103+
&this.mode,
104+
&batch,
105+
&mut this.accumulators,
106+
&this.aggregate_expressions,
107+
);
108+
109+
timer.done();
110+
111+
// allocate memory
112+
// This happens AFTER we actually used the memory, but simplifies the whole accounting and we are OK with
113+
// overshooting a bit. Also this means we either store the whole record batch or not.
114+
let result = match result {
115+
Ok(allocated) => this.memory_consumer.alloc(allocated).await,
116+
Err(e) => Err(e),
117+
};
118+
119+
match result {
120+
Ok(_) => continue,
121+
Err(e) => Err(ArrowError::ExternalError(Box::new(e))),
122+
}
123+
}
124+
Some(Err(e)) => Err(e),
125+
None => {
126+
this.finished = true;
127+
let timer = this.baseline_metrics.elapsed_compute().timer();
128+
let result = finalize_aggregation(&this.accumulators, &this.mode)
129+
.map_err(|e| ArrowError::ExternalError(Box::new(e)))
130+
.and_then(|columns| {
131+
RecordBatch::try_new(this.schema.clone(), columns)
132+
})
133+
.record_output(&this.baseline_metrics);
134+
135+
timer.done();
136+
137+
result
138+
}
139+
};
140+
141+
this.finished = true;
142+
return Some((result, this));
143+
}
144+
});
145+
146+
// seems like some consumers call this stream even after it returned `None`, so let's fuse the stream.
147+
let stream = stream.fuse();
148+
let stream = Box::pin(stream);
149+
150+
Ok(Self { schema, stream })
71151
}
72152
}
73153

@@ -79,49 +159,7 @@ impl Stream for AggregateStream {
79159
cx: &mut Context<'_>,
80160
) -> Poll<Option<Self::Item>> {
81161
let this = &mut *self;
82-
if this.finished {
83-
return Poll::Ready(None);
84-
}
85-
86-
let elapsed_compute = this.baseline_metrics.elapsed_compute();
87-
88-
loop {
89-
let result = match ready!(this.input.poll_next_unpin(cx)) {
90-
Some(Ok(batch)) => {
91-
let timer = elapsed_compute.timer();
92-
let result = aggregate_batch(
93-
&this.mode,
94-
&batch,
95-
&mut this.accumulators,
96-
&this.aggregate_expressions,
97-
);
98-
99-
timer.done();
100-
101-
match result {
102-
Ok(_) => continue,
103-
Err(e) => Err(ArrowError::ExternalError(Box::new(e))),
104-
}
105-
}
106-
Some(Err(e)) => Err(e),
107-
None => {
108-
this.finished = true;
109-
let timer = this.baseline_metrics.elapsed_compute().timer();
110-
let result = finalize_aggregation(&this.accumulators, &this.mode)
111-
.map_err(|e| ArrowError::ExternalError(Box::new(e)))
112-
.and_then(|columns| {
113-
RecordBatch::try_new(this.schema.clone(), columns)
114-
})
115-
.record_output(&this.baseline_metrics);
116-
117-
timer.done();
118-
result
119-
}
120-
};
121-
122-
this.finished = true;
123-
return Poll::Ready(Some(result));
124-
}
162+
this.stream.poll_next_unpin(cx)
125163
}
126164
}
127165

@@ -131,13 +169,19 @@ impl RecordBatchStream for AggregateStream {
131169
}
132170
}
133171

172+
/// Perform group-by aggregation for the given [`RecordBatch`].
173+
///
174+
/// If successfull, this returns the additional number of bytes that were allocated during this process.
175+
///
134176
/// TODO: Make this a member function
135177
fn aggregate_batch(
136178
mode: &AggregateMode,
137179
batch: &RecordBatch,
138180
accumulators: &mut [AccumulatorItem],
139181
expressions: &[Vec<Arc<dyn PhysicalExpr>>],
140-
) -> Result<()> {
182+
) -> Result<usize> {
183+
let mut allocated = 0usize;
184+
141185
// 1.1 iterate accumulators and respective expressions together
142186
// 1.2 evaluate expressions
143187
// 1.3 update / merge accumulators with the expressions' values
@@ -155,11 +199,17 @@ fn aggregate_batch(
155199
.collect::<Result<Vec<_>>>()?;
156200

157201
// 1.3
158-
match mode {
202+
let size_pre = accum.size();
203+
let res = match mode {
159204
AggregateMode::Partial => accum.update_batch(values),
160205
AggregateMode::Final | AggregateMode::FinalPartitioned => {
161206
accum.merge_batch(values)
162207
}
163-
}
164-
})
208+
};
209+
let size_post = accum.size();
210+
allocated += size_post.saturating_sub(size_pre);
211+
res
212+
})?;
213+
214+
Ok(allocated)
165215
}

0 commit comments

Comments
 (0)