Skip to content

Commit e693ed7

Browse files
AggregateExec: Take grouping sets into account for InputOrderMode (#11301)
* AggregateExec: Take grouping sets into account for InputOrderMode * pr comments
1 parent 08c5345 commit e693ed7

File tree

1 file changed

+113
-8
lines changed
  • datafusion/physical-plan/src/aggregates

1 file changed

+113
-8
lines changed

datafusion/physical-plan/src/aggregates/mod.rs

Lines changed: 113 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -369,14 +369,26 @@ impl AggregateExec {
369369
new_requirement.extend(req);
370370
new_requirement = collapse_lex_req(new_requirement);
371371

372-
let input_order_mode =
373-
if indices.len() == groupby_exprs.len() && !indices.is_empty() {
374-
InputOrderMode::Sorted
375-
} else if !indices.is_empty() {
376-
InputOrderMode::PartiallySorted(indices)
377-
} else {
378-
InputOrderMode::Linear
379-
};
372+
// If our aggregation has grouping sets then our base grouping exprs will
373+
// be expanded based on the flags in `group_by.groups` where for each
374+
// group we swap the grouping expr for `null` if the flag is `true`
375+
// That means that each index in `indices` is valid if and only if
376+
// it is not null in every group
377+
let indices: Vec<usize> = indices
378+
.into_iter()
379+
.filter(|idx| group_by.groups.iter().all(|group| !group[*idx]))
380+
.collect();
381+
382+
let input_order_mode = if indices.len() == groupby_exprs.len()
383+
&& !indices.is_empty()
384+
&& group_by.groups.len() == 1
385+
{
386+
InputOrderMode::Sorted
387+
} else if !indices.is_empty() {
388+
InputOrderMode::PartiallySorted(indices)
389+
} else {
390+
InputOrderMode::Linear
391+
};
380392

381393
// construct a map from the input expression to the output expression of the Aggregation group by
382394
let projection_mapping =
@@ -1180,6 +1192,7 @@ mod tests {
11801192
use arrow::array::{Float64Array, UInt32Array};
11811193
use arrow::compute::{concat_batches, SortOptions};
11821194
use arrow::datatypes::DataType;
1195+
use arrow_array::{Float32Array, Int32Array};
11831196
use datafusion_common::{
11841197
assert_batches_eq, assert_batches_sorted_eq, internal_err, DataFusionError,
11851198
ScalarValue,
@@ -1195,7 +1208,9 @@ mod tests {
11951208
use datafusion_physical_expr::expressions::{lit, OrderSensitiveArrayAgg};
11961209
use datafusion_physical_expr::PhysicalSortExpr;
11971210

1211+
use crate::common::collect;
11981212
use datafusion_physical_expr_common::aggregate::create_aggregate_expr;
1213+
use datafusion_physical_expr_common::expressions::Literal;
11991214
use futures::{FutureExt, Stream};
12001215

12011216
// Generate a schema which consists of 5 columns (a, b, c, d, e)
@@ -2267,4 +2282,94 @@ mod tests {
22672282
assert_eq!(new_agg.schema(), aggregate_exec.schema());
22682283
Ok(())
22692284
}
2285+
2286+
#[tokio::test]
2287+
async fn test_agg_exec_group_by_const() -> Result<()> {
2288+
let schema = Arc::new(Schema::new(vec![
2289+
Field::new("a", DataType::Float32, true),
2290+
Field::new("b", DataType::Float32, true),
2291+
Field::new("const", DataType::Int32, false),
2292+
]));
2293+
2294+
let col_a = col("a", &schema)?;
2295+
let col_b = col("b", &schema)?;
2296+
let const_expr = Arc::new(Literal::new(ScalarValue::Int32(Some(1))));
2297+
2298+
let groups = PhysicalGroupBy::new(
2299+
vec![
2300+
(col_a, "a".to_string()),
2301+
(col_b, "b".to_string()),
2302+
(const_expr, "const".to_string()),
2303+
],
2304+
vec![
2305+
(
2306+
Arc::new(Literal::new(ScalarValue::Float32(None))),
2307+
"a".to_string(),
2308+
),
2309+
(
2310+
Arc::new(Literal::new(ScalarValue::Float32(None))),
2311+
"b".to_string(),
2312+
),
2313+
(
2314+
Arc::new(Literal::new(ScalarValue::Int32(None))),
2315+
"const".to_string(),
2316+
),
2317+
],
2318+
vec![
2319+
vec![false, true, true],
2320+
vec![true, false, true],
2321+
vec![true, true, false],
2322+
],
2323+
);
2324+
2325+
let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![create_aggregate_expr(
2326+
count_udaf().as_ref(),
2327+
&[lit(1)],
2328+
&[datafusion_expr::lit(1)],
2329+
&[],
2330+
&[],
2331+
schema.as_ref(),
2332+
"1",
2333+
false,
2334+
false,
2335+
)?];
2336+
2337+
let input_batches = (0..4)
2338+
.map(|_| {
2339+
let a = Arc::new(Float32Array::from(vec![0.; 8192]));
2340+
let b = Arc::new(Float32Array::from(vec![0.; 8192]));
2341+
let c = Arc::new(Int32Array::from(vec![1; 8192]));
2342+
2343+
RecordBatch::try_new(schema.clone(), vec![a, b, c]).unwrap()
2344+
})
2345+
.collect();
2346+
2347+
let input =
2348+
Arc::new(MemoryExec::try_new(&[input_batches], schema.clone(), None)?);
2349+
2350+
let aggregate_exec = Arc::new(AggregateExec::try_new(
2351+
AggregateMode::Partial,
2352+
groups,
2353+
aggregates.clone(),
2354+
vec![None],
2355+
input,
2356+
schema,
2357+
)?);
2358+
2359+
let output =
2360+
collect(aggregate_exec.execute(0, Arc::new(TaskContext::default()))?).await?;
2361+
2362+
let expected = [
2363+
"+-----+-----+-------+----------+",
2364+
"| a | b | const | 1[count] |",
2365+
"+-----+-----+-------+----------+",
2366+
"| | 0.0 | | 32768 |",
2367+
"| 0.0 | | | 32768 |",
2368+
"| | | 1 | 32768 |",
2369+
"+-----+-----+-------+----------+",
2370+
];
2371+
assert_batches_sorted_eq!(expected, &output);
2372+
2373+
Ok(())
2374+
}
22702375
}

0 commit comments

Comments
 (0)