Skip to content

Commit f4e519f

Browse files
authored
Move min and max to user defined aggregate function, remove AggregateFunction / AggregateFunctionDefinition::BuiltIn (#11013)
* Moving min and max to new API and removing from protobuf
* Using input_type rather than data_type
* Adding type coercion
* Fixed doctests
* Implementing feedback from code review
* Implementing feedback from code review
* Fixed wrong name
* Fixing name
1 parent 9e90e17 commit f4e519f

File tree

56 files changed

+937
-1813
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

56 files changed

+937
-1813
lines changed

datafusion-examples/examples/dataframe_subquery.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ use std::sync::Arc;
2020

2121
use datafusion::error::Result;
2222
use datafusion::functions_aggregate::average::avg;
23+
use datafusion::functions_aggregate::min_max::max;
2324
use datafusion::prelude::*;
2425
use datafusion::test_util::arrow_test_data;
2526
use datafusion_common::ScalarValue;

datafusion/core/src/dataframe/mod.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,11 @@ use datafusion_common::{
5353
};
5454
use datafusion_expr::{case, is_null, lit};
5555
use datafusion_expr::{
56-
max, min, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE,
56+
utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE,
57+
};
58+
use datafusion_functions_aggregate::expr_fn::{
59+
avg, count, max, median, min, stddev, sum,
5760
};
58-
use datafusion_functions_aggregate::expr_fn::{avg, count, median, stddev, sum};
5961

6062
use async_trait::async_trait;
6163
use datafusion_catalog::Session;
@@ -144,6 +146,7 @@ impl Default for DataFrameWriteOptions {
144146
/// ```
145147
/// # use datafusion::prelude::*;
146148
/// # use datafusion::error::Result;
149+
/// # use datafusion::functions_aggregate::expr_fn::min;
147150
/// # #[tokio::main]
148151
/// # async fn main() -> Result<()> {
149152
/// let ctx = SessionContext::new();
@@ -407,6 +410,7 @@ impl DataFrame {
407410
/// ```
408411
/// # use datafusion::prelude::*;
409412
/// # use datafusion::error::Result;
413+
/// # use datafusion::functions_aggregate::expr_fn::min;
410414
/// # #[tokio::main]
411415
/// # async fn main() -> Result<()> {
412416
/// let ctx = SessionContext::new();

datafusion/core/src/datasource/file_format/parquet.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ use datafusion_common::{
5050
use datafusion_common_runtime::SpawnedTask;
5151
use datafusion_execution::memory_pool::{MemoryConsumer, MemoryPool, MemoryReservation};
5252
use datafusion_execution::TaskContext;
53-
use datafusion_physical_expr::expressions::{MaxAccumulator, MinAccumulator};
53+
use datafusion_functions_aggregate::min_max::{MaxAccumulator, MinAccumulator};
5454
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement};
5555
use datafusion_physical_plan::metrics::MetricsSet;
5656

datafusion/core/src/datasource/statistics.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
use super::listing::PartitionedFile;
1919
use crate::arrow::datatypes::{Schema, SchemaRef};
2020
use crate::error::Result;
21-
use crate::physical_plan::expressions::{MaxAccumulator, MinAccumulator};
21+
use crate::functions_aggregate::min_max::{MaxAccumulator, MinAccumulator};
2222
use crate::physical_plan::{Accumulator, ColumnStatistics, Statistics};
2323
use arrow_schema::DataType;
2424

datafusion/core/src/execution/context/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ where
144144
///
145145
/// ```
146146
/// use datafusion::prelude::*;
147+
/// # use datafusion::functions_aggregate::expr_fn::min;
147148
/// # use datafusion::{error::Result, assert_batches_eq};
148149
/// # #[tokio::main]
149150
/// # async fn main() -> Result<()> {

datafusion/core/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
//! ```rust
5353
//! # use datafusion::prelude::*;
5454
//! # use datafusion::error::Result;
55+
//! # use datafusion::functions_aggregate::expr_fn::min;
5556
//! # use datafusion::arrow::record_batch::RecordBatch;
5657
//!
5758
//! # #[tokio::main]

datafusion/core/src/physical_optimizer/aggregate_statistics.rs

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -272,39 +272,28 @@ fn is_non_distinct_count(agg_expr: &dyn AggregateExpr) -> bool {
272272
return true;
273273
}
274274
}
275-
276275
false
277276
}
278277

279278
// TODO: Move this check into AggregateUDFImpl
280279
// https://github.com/apache/datafusion/issues/11153
281280
fn is_min(agg_expr: &dyn AggregateExpr) -> bool {
282-
if agg_expr.as_any().is::<expressions::Min>() {
283-
return true;
284-
}
285-
286281
if let Some(agg_expr) = agg_expr.as_any().downcast_ref::<AggregateFunctionExpr>() {
287-
if agg_expr.fun().name() == "min" {
282+
if agg_expr.fun().name().to_lowercase() == "min" {
288283
return true;
289284
}
290285
}
291-
292286
false
293287
}
294288

295289
// TODO: Move this check into AggregateUDFImpl
296290
// https://github.com/apache/datafusion/issues/11153
297291
fn is_max(agg_expr: &dyn AggregateExpr) -> bool {
298-
if agg_expr.as_any().is::<expressions::Max>() {
299-
return true;
300-
}
301-
302292
if let Some(agg_expr) = agg_expr.as_any().downcast_ref::<AggregateFunctionExpr>() {
303-
if agg_expr.fun().name() == "max" {
293+
if agg_expr.fun().name().to_lowercase() == "max" {
304294
return true;
305295
}
306296
}
307-
308297
false
309298
}
310299

datafusion/core/src/physical_planner.rs

Lines changed: 3 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ use crate::physical_plan::unnest::UnnestExec;
5959
use crate::physical_plan::values::ValuesExec;
6060
use crate::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec};
6161
use crate::physical_plan::{
62-
aggregates, displayable, udaf, windows, AggregateExpr, ExecutionPlan,
63-
ExecutionPlanProperties, InputOrderMode, Partitioning, PhysicalExpr, WindowExpr,
62+
displayable, udaf, windows, AggregateExpr, ExecutionPlan, ExecutionPlanProperties,
63+
InputOrderMode, Partitioning, PhysicalExpr, WindowExpr,
6464
};
6565

6666
use arrow::compute::SortOptions;
@@ -1812,7 +1812,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
18121812
e: &Expr,
18131813
name: impl Into<String>,
18141814
logical_input_schema: &DFSchema,
1815-
physical_input_schema: &Schema,
1815+
_physical_input_schema: &Schema,
18161816
execution_props: &ExecutionProps,
18171817
) -> Result<AggregateExprWithOptionalArgs> {
18181818
match e {
@@ -1840,28 +1840,6 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
18401840
== NullTreatment::IgnoreNulls;
18411841

18421842
let (agg_expr, filter, order_by) = match func_def {
1843-
AggregateFunctionDefinition::BuiltIn(fun) => {
1844-
let physical_sort_exprs = match order_by {
1845-
Some(exprs) => Some(create_physical_sort_exprs(
1846-
exprs,
1847-
logical_input_schema,
1848-
execution_props,
1849-
)?),
1850-
None => None,
1851-
};
1852-
let ordering_reqs: Vec<PhysicalSortExpr> =
1853-
physical_sort_exprs.clone().unwrap_or(vec![]);
1854-
let agg_expr = aggregates::create_aggregate_expr(
1855-
fun,
1856-
*distinct,
1857-
&physical_args,
1858-
&ordering_reqs,
1859-
physical_input_schema,
1860-
name,
1861-
ignore_nulls,
1862-
)?;
1863-
(agg_expr, filter, physical_sort_exprs)
1864-
}
18651843
AggregateFunctionDefinition::UDF(fun) => {
18661844
let sort_exprs = order_by.clone().unwrap_or(vec![]);
18671845
let physical_sort_exprs = match order_by {

datafusion/core/tests/dataframe/mod.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,11 @@ use datafusion_execution::runtime_env::RuntimeEnv;
5454
use datafusion_expr::expr::{GroupingSet, Sort};
5555
use datafusion_expr::var_provider::{VarProvider, VarType};
5656
use datafusion_expr::{
57-
cast, col, exists, expr, in_subquery, lit, max, out_ref_col, placeholder,
58-
scalar_subquery, when, wildcard, Expr, ExprFunctionExt, ExprSchemable, WindowFrame,
59-
WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
57+
cast, col, exists, expr, in_subquery, lit, out_ref_col, placeholder, scalar_subquery,
58+
when, wildcard, Expr, ExprFunctionExt, ExprSchemable, WindowFrame, WindowFrameBound,
59+
WindowFrameUnits, WindowFunctionDefinition,
6060
};
61-
use datafusion_functions_aggregate::expr_fn::{array_agg, avg, count, sum};
61+
use datafusion_functions_aggregate::expr_fn::{array_agg, avg, count, max, sum};
6262

6363
#[tokio::test]
6464
async fn test_count_wildcard_on_sort() -> Result<()> {

datafusion/core/tests/fuzz_cases/window_fuzz.rs

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,13 @@ use datafusion::physical_plan::{collect, InputOrderMode};
3232
use datafusion::prelude::{SessionConfig, SessionContext};
3333
use datafusion_common::{Result, ScalarValue};
3434
use datafusion_common_runtime::SpawnedTask;
35-
use datafusion_expr::type_coercion::aggregates::coerce_types;
3635
use datafusion_expr::type_coercion::functions::data_types_with_aggregate_udf;
3736
use datafusion_expr::{
38-
AggregateFunction, BuiltInWindowFunction, WindowFrame, WindowFrameBound,
39-
WindowFrameUnits, WindowFunctionDefinition,
37+
BuiltInWindowFunction, WindowFrame, WindowFrameBound, WindowFrameUnits,
38+
WindowFunctionDefinition,
4039
};
4140
use datafusion_functions_aggregate::count::count_udaf;
41+
use datafusion_functions_aggregate::min_max::{max_udaf, min_udaf};
4242
use datafusion_functions_aggregate::sum::sum_udaf;
4343
use datafusion_physical_expr::expressions::{cast, col, lit};
4444
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
@@ -361,14 +361,14 @@ fn get_random_function(
361361
window_fn_map.insert(
362362
"min",
363363
(
364-
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min),
364+
WindowFunctionDefinition::AggregateUDF(min_udaf()),
365365
vec![arg.clone()],
366366
),
367367
);
368368
window_fn_map.insert(
369369
"max",
370370
(
371-
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max),
371+
WindowFunctionDefinition::AggregateUDF(max_udaf()),
372372
vec![arg.clone()],
373373
),
374374
);
@@ -465,16 +465,7 @@ fn get_random_function(
465465
let fn_name = window_fn_map.keys().collect::<Vec<_>>()[rand_fn_idx];
466466
let (window_fn, args) = window_fn_map.values().collect::<Vec<_>>()[rand_fn_idx];
467467
let mut args = args.clone();
468-
if let WindowFunctionDefinition::AggregateFunction(f) = window_fn {
469-
if !args.is_empty() {
470-
// Do type coercion first argument
471-
let a = args[0].clone();
472-
let dt = a.data_type(schema.as_ref()).unwrap();
473-
let sig = f.signature();
474-
let coerced = coerce_types(f, &[dt], &sig).unwrap();
475-
args[0] = cast(a, schema, coerced[0].clone()).unwrap();
476-
}
477-
} else if let WindowFunctionDefinition::AggregateUDF(udf) = window_fn {
468+
if let WindowFunctionDefinition::AggregateUDF(udf) = window_fn {
478469
if !args.is_empty() {
479470
// Do type coercion first argument
480471
let a = args[0].clone();

datafusion/expr/src/aggregate_function.rs

Lines changed: 0 additions & 156 deletions
This file was deleted.

0 commit comments

Comments
 (0)