@@ -369,14 +369,26 @@ impl AggregateExec {
369
369
new_requirement. extend ( req) ;
370
370
new_requirement = collapse_lex_req ( new_requirement) ;
371
371
372
- let input_order_mode =
373
- if indices. len ( ) == groupby_exprs. len ( ) && !indices. is_empty ( ) {
374
- InputOrderMode :: Sorted
375
- } else if !indices. is_empty ( ) {
376
- InputOrderMode :: PartiallySorted ( indices)
377
- } else {
378
- InputOrderMode :: Linear
379
- } ;
372
+ // If our aggregation has grouping sets then our base grouping exprs will
373
+ // be expanded based on the flags in `group_by.groups` where for each
374
+ // group we swap the grouping expr for `null` if the flag is `true`.
375
+ // That means that each index in `indices` is valid if and only if
376
+ // its flag is `false` (i.e. the expr is not swapped for `null`) in every group.
377
+ let indices: Vec < usize > = indices
378
+ . into_iter ( )
379
+ . filter ( |idx| group_by. groups . iter ( ) . all ( |group| !group[ * idx] ) )
380
+ . collect ( ) ;
381
+
382
+ let input_order_mode = if indices. len ( ) == groupby_exprs. len ( )
383
+ && !indices. is_empty ( )
384
+ && group_by. groups . len ( ) == 1
385
+ {
386
+ InputOrderMode :: Sorted
387
+ } else if !indices. is_empty ( ) {
388
+ InputOrderMode :: PartiallySorted ( indices)
389
+ } else {
390
+ InputOrderMode :: Linear
391
+ } ;
380
392
381
393
// construct a map from the input expression to the output expression of the Aggregation group by
382
394
let projection_mapping =
@@ -1180,6 +1192,7 @@ mod tests {
1180
1192
use arrow:: array:: { Float64Array , UInt32Array } ;
1181
1193
use arrow:: compute:: { concat_batches, SortOptions } ;
1182
1194
use arrow:: datatypes:: DataType ;
1195
+ use arrow_array:: { Float32Array , Int32Array } ;
1183
1196
use datafusion_common:: {
1184
1197
assert_batches_eq, assert_batches_sorted_eq, internal_err, DataFusionError ,
1185
1198
ScalarValue ,
@@ -1195,7 +1208,9 @@ mod tests {
1195
1208
use datafusion_physical_expr:: expressions:: { lit, OrderSensitiveArrayAgg } ;
1196
1209
use datafusion_physical_expr:: PhysicalSortExpr ;
1197
1210
1211
+ use crate :: common:: collect;
1198
1212
use datafusion_physical_expr_common:: aggregate:: create_aggregate_expr;
1213
+ use datafusion_physical_expr_common:: expressions:: Literal ;
1199
1214
use futures:: { FutureExt , Stream } ;
1200
1215
1201
1216
// Generate a schema which consists of 5 columns (a, b, c, d, e)
@@ -2267,4 +2282,94 @@ mod tests {
2267
2282
assert_eq ! ( new_agg. schema( ) , aggregate_exec. schema( ) ) ;
2268
2283
Ok ( ( ) )
2269
2284
}
2285
+
2286
+ #[ tokio:: test] // Runs a Partial AggregateExec over three grouping sets built from columns `a`, `b` and a literal constant.
2287
+ async fn test_agg_exec_group_by_const ( ) -> Result < ( ) > {
2288
+ let schema = Arc :: new ( Schema :: new ( vec ! [
2289
+ Field :: new( "a" , DataType :: Float32 , true ) ,
2290
+ Field :: new( "b" , DataType :: Float32 , true ) ,
2291
+ Field :: new( "const" , DataType :: Int32 , false ) ,
2292
+ ] ) ) ;
2293
+ // Grouping expressions: column refs for `a` and `b`, plus the literal Int32 value 1.
2294
+ let col_a = col ( "a" , & schema) ?;
2295
+ let col_b = col ( "b" , & schema) ?;
2296
+ let const_expr = Arc :: new ( Literal :: new ( ScalarValue :: Int32 ( Some ( 1 ) ) ) ) ;
2297
+ // Three arguments follow: named grouping exprs, typed NULL literals (one per expr), and one flag row per grouping set.
2298
+ let groups = PhysicalGroupBy :: new (
2299
+ vec ! [
2300
+ ( col_a, "a" . to_string( ) ) ,
2301
+ ( col_b, "b" . to_string( ) ) ,
2302
+ ( const_expr, "const" . to_string( ) ) ,
2303
+ ] ,
2304
+ vec ! [ // NOTE(review): presumably these are the values substituted when a flag below is `true` - confirm against PhysicalGroupBy
2305
+ (
2306
+ Arc :: new( Literal :: new( ScalarValue :: Float32 ( None ) ) ) ,
2307
+ "a" . to_string( ) ,
2308
+ ) ,
2309
+ (
2310
+ Arc :: new( Literal :: new( ScalarValue :: Float32 ( None ) ) ) ,
2311
+ "b" . to_string( ) ,
2312
+ ) ,
2313
+ (
2314
+ Arc :: new( Literal :: new( ScalarValue :: Int32 ( None ) ) ) ,
2315
+ "const" . to_string( ) ,
2316
+ ) ,
2317
+ ] ,
2318
+ vec ! [ // a `true` flag means that grouping expr is swapped for null in that grouping set
2319
+ vec![ false , true , true ] , // set 1: only `a` is kept
2320
+ vec![ true , false , true ] , // set 2: only `b` is kept
2321
+ vec![ true , true , false ] , // set 3: only `const` is kept
2322
+ ] ,
2323
+ ) ;
2324
+ // Single aggregate: the count UDAF applied to the literal 1, named "1".
2325
+ let aggregates: Vec < Arc < dyn AggregateExpr > > = vec ! [ create_aggregate_expr(
2326
+ count_udaf( ) . as_ref( ) ,
2327
+ & [ lit( 1 ) ] ,
2328
+ & [ datafusion_expr:: lit( 1 ) ] ,
2329
+ & [ ] ,
2330
+ & [ ] ,
2331
+ schema. as_ref( ) ,
2332
+ "1" ,
2333
+ false ,
2334
+ false ,
2335
+ ) ?] ;
2336
+ // Input: 4 identical batches of 8192 rows each (a = 0.0, b = 0.0, const = 1), i.e. 32768 rows in total.
2337
+ let input_batches = ( 0 ..4 )
2338
+ . map ( |_| {
2339
+ let a = Arc :: new ( Float32Array :: from ( vec ! [ 0. ; 8192 ] ) ) ;
2340
+ let b = Arc :: new ( Float32Array :: from ( vec ! [ 0. ; 8192 ] ) ) ;
2341
+ let c = Arc :: new ( Int32Array :: from ( vec ! [ 1 ; 8192 ] ) ) ;
2342
+
2343
+ RecordBatch :: try_new ( schema. clone ( ) , vec ! [ a, b, c] ) . unwrap ( )
2344
+ } )
2345
+ . collect ( ) ;
2346
+ // All four batches are handed to MemoryExec as a single partition.
2347
+ let input =
2348
+ Arc :: new ( MemoryExec :: try_new ( & [ input_batches] , schema. clone ( ) , None ) ?) ;
2349
+
2350
+ let aggregate_exec = Arc :: new ( AggregateExec :: try_new (
2351
+ AggregateMode :: Partial ,
2352
+ groups,
2353
+ aggregates. clone ( ) ,
2354
+ vec ! [ None ] ,
2355
+ input,
2356
+ schema,
2357
+ ) ?) ;
2358
+ // Execute partition 0 and collect the resulting batches.
2359
+ let output =
2360
+ collect ( aggregate_exec. execute ( 0 , Arc :: new ( TaskContext :: default ( ) ) ) ?) . await ?;
2361
+ // Expect one row per grouping set, each counting all 32768 input rows (4 * 8192).
2362
+ let expected = [
2363
+ "+-----+-----+-------+----------+" ,
2364
+ "| a | b | const | 1[count] |" ,
2365
+ "+-----+-----+-------+----------+" ,
2366
+ "| | 0.0 | | 32768 |" ,
2367
+ "| 0.0 | | | 32768 |" ,
2368
+ "| | | 1 | 32768 |" ,
2369
+ "+-----+-----+-------+----------+" ,
2370
+ ] ;
2371
+ assert_batches_sorted_eq ! ( expected, & output) ;
2372
+
2373
+ Ok ( ( ) )
2374
+ }
2270
2375
}
0 commit comments