Skip to content

Commit ece4555

Browse files
authored
AggregateUDFImpl::schema_name and AggregateUDFImpl::display_name for customizable name (#14695)
* udaf schema_name * doc * fix proto * fmt * fix * fmt * doc * add displayname * doc
1 parent 8b45d2d commit ece4555

File tree

16 files changed

+280
-159
lines changed

16 files changed

+280
-159
lines changed

datafusion-examples/examples/advanced_udaf.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -423,11 +423,11 @@ impl AggregateUDFImpl for SimplifiedGeoMeanUdaf {
423423
// In real-world scenarios, you might create UDFs from built-in expressions.
424424
Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
425425
Arc::new(AggregateUDF::from(GeoMeanUdaf::new())),
426-
aggregate_function.args,
427-
aggregate_function.distinct,
428-
aggregate_function.filter,
429-
aggregate_function.order_by,
430-
aggregate_function.null_treatment,
426+
aggregate_function.params.args,
427+
aggregate_function.params.distinct,
428+
aggregate_function.params.filter,
429+
aggregate_function.params.order_by,
430+
aggregate_function.params.null_treatment,
431431
)))
432432
};
433433
Some(Box::new(simplify))

datafusion/core/src/physical_planner.rs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@ use datafusion_common::{
7070
};
7171
use datafusion_expr::dml::{CopyTo, InsertOp};
7272
use datafusion_expr::expr::{
73-
physical_name, AggregateFunction, Alias, GroupingSet, WindowFunction,
73+
physical_name, AggregateFunction, AggregateFunctionParams, Alias, GroupingSet,
74+
WindowFunction,
7475
};
7576
use datafusion_expr::expr_rewriter::unnormalize_cols;
7677
use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary;
@@ -1579,11 +1580,14 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
15791580
match e {
15801581
Expr::AggregateFunction(AggregateFunction {
15811582
func,
1582-
distinct,
1583-
args,
1584-
filter,
1585-
order_by,
1586-
null_treatment,
1583+
params:
1584+
AggregateFunctionParams {
1585+
args,
1586+
distinct,
1587+
filter,
1588+
order_by,
1589+
null_treatment,
1590+
},
15871591
}) => {
15881592
let name = if let Some(name) = name {
15891593
name

datafusion/core/tests/execution/logical_plan.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use arrow::datatypes::{DataType, Field};
2020
use datafusion::execution::session_state::SessionStateBuilder;
2121
use datafusion_common::{Column, DFSchema, Result, ScalarValue, Spans};
2222
use datafusion_execution::TaskContext;
23-
use datafusion_expr::expr::AggregateFunction;
23+
use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
2424
use datafusion_expr::logical_plan::{LogicalPlan, Values};
2525
use datafusion_expr::{Aggregate, AggregateUDF, Expr};
2626
use datafusion_functions_aggregate::count::Count;
@@ -60,11 +60,13 @@ async fn count_only_nulls() -> Result<()> {
6060
vec![],
6161
vec![Expr::AggregateFunction(AggregateFunction {
6262
func: Arc::new(AggregateUDF::new_from_impl(Count::new())),
63-
args: vec![input_col_ref],
64-
distinct: false,
65-
filter: None,
66-
order_by: None,
67-
null_treatment: None,
63+
params: AggregateFunctionParams {
64+
args: vec![input_col_ref],
65+
distinct: false,
66+
filter: None,
67+
order_by: None,
68+
null_treatment: None,
69+
},
6870
})],
6971
)?);
7072

datafusion/expr/src/expr.rs

Lines changed: 52 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -696,7 +696,11 @@ impl<'a> TreeNodeContainer<'a, Expr> for Sort {
696696
pub struct AggregateFunction {
697697
/// Name of the function
698698
pub func: Arc<crate::AggregateUDF>,
699-
/// List of expressions to feed to the functions as arguments
699+
pub params: AggregateFunctionParams,
700+
}
701+
702+
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
703+
pub struct AggregateFunctionParams {
700704
pub args: Vec<Expr>,
701705
/// Whether this is a DISTINCT aggregation or not
702706
pub distinct: bool,
@@ -719,11 +723,13 @@ impl AggregateFunction {
719723
) -> Self {
720724
Self {
721725
func,
722-
args,
723-
distinct,
724-
filter,
725-
order_by,
726-
null_treatment,
726+
params: AggregateFunctionParams {
727+
args,
728+
distinct,
729+
filter,
730+
order_by,
731+
null_treatment,
732+
},
727733
}
728734
}
729735
}
@@ -1864,19 +1870,25 @@ impl NormalizeEq for Expr {
18641870
(
18651871
Expr::AggregateFunction(AggregateFunction {
18661872
func: self_func,
1867-
args: self_args,
1868-
distinct: self_distinct,
1869-
filter: self_filter,
1870-
order_by: self_order_by,
1871-
null_treatment: self_null_treatment,
1873+
params:
1874+
AggregateFunctionParams {
1875+
args: self_args,
1876+
distinct: self_distinct,
1877+
filter: self_filter,
1878+
order_by: self_order_by,
1879+
null_treatment: self_null_treatment,
1880+
},
18721881
}),
18731882
Expr::AggregateFunction(AggregateFunction {
18741883
func: other_func,
1875-
args: other_args,
1876-
distinct: other_distinct,
1877-
filter: other_filter,
1878-
order_by: other_order_by,
1879-
null_treatment: other_null_treatment,
1884+
params:
1885+
AggregateFunctionParams {
1886+
args: other_args,
1887+
distinct: other_distinct,
1888+
filter: other_filter,
1889+
order_by: other_order_by,
1890+
null_treatment: other_null_treatment,
1891+
},
18801892
}),
18811893
) => {
18821894
self_func.name() == other_func.name()
@@ -2154,11 +2166,14 @@ impl HashNode for Expr {
21542166
}
21552167
Expr::AggregateFunction(AggregateFunction {
21562168
func,
2157-
args: _args,
2158-
distinct,
2159-
filter: _filter,
2160-
order_by: _order_by,
2161-
null_treatment,
2169+
params:
2170+
AggregateFunctionParams {
2171+
args: _args,
2172+
distinct,
2173+
filter: _,
2174+
order_by: _,
2175+
null_treatment,
2176+
},
21622177
}) => {
21632178
func.hash(state);
21642179
distinct.hash(state);
@@ -2264,35 +2279,15 @@ impl Display for SchemaDisplay<'_> {
22642279
| Expr::Placeholder(_)
22652280
| Expr::Wildcard { .. } => write!(f, "{}", self.0),
22662281

2267-
Expr::AggregateFunction(AggregateFunction {
2268-
func,
2269-
args,
2270-
distinct,
2271-
filter,
2272-
order_by,
2273-
null_treatment,
2274-
}) => {
2275-
write!(
2276-
f,
2277-
"{}({}{})",
2278-
func.name(),
2279-
if *distinct { "DISTINCT " } else { "" },
2280-
schema_name_from_exprs_comma_separated_without_space(args)?
2281-
)?;
2282-
2283-
if let Some(null_treatment) = null_treatment {
2284-
write!(f, " {}", null_treatment)?;
2282+
Expr::AggregateFunction(AggregateFunction { func, params }) => {
2283+
match func.schema_name(params) {
2284+
Ok(name) => {
2285+
write!(f, "{name}")
2286+
}
2287+
Err(e) => {
2288+
write!(f, "got error from schema_name {}", e)
2289+
}
22852290
}
2286-
2287-
if let Some(filter) = filter {
2288-
write!(f, " FILTER (WHERE {filter})")?;
2289-
};
2290-
2291-
if let Some(order_by) = order_by {
2292-
write!(f, " ORDER BY [{}]", schema_name_from_sorts(order_by)?)?;
2293-
};
2294-
2295-
Ok(())
22962291
}
22972292
// Expr is not shown since it is aliased
22982293
Expr::Alias(Alias {
@@ -2653,26 +2648,15 @@ impl Display for Expr {
26532648
)?;
26542649
Ok(())
26552650
}
2656-
Expr::AggregateFunction(AggregateFunction {
2657-
func,
2658-
distinct,
2659-
ref args,
2660-
filter,
2661-
order_by,
2662-
null_treatment,
2663-
..
2664-
}) => {
2665-
fmt_function(f, func.name(), *distinct, args, true)?;
2666-
if let Some(nt) = null_treatment {
2667-
write!(f, " {}", nt)?;
2668-
}
2669-
if let Some(fe) = filter {
2670-
write!(f, " FILTER (WHERE {fe})")?;
2671-
}
2672-
if let Some(ob) = order_by {
2673-
write!(f, " ORDER BY [{}]", expr_vec_fmt!(ob))?;
2651+
Expr::AggregateFunction(AggregateFunction { func, params }) => {
2652+
match func.display_name(params) {
2653+
Ok(name) => {
2654+
write!(f, "{}", name)
2655+
}
2656+
Err(e) => {
2657+
write!(f, "got error from display_name {}", e)
2658+
}
26742659
}
2675-
Ok(())
26762660
}
26772661
Expr::Between(Between {
26782662
expr,

datafusion/expr/src/expr_fn.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -826,10 +826,10 @@ impl ExprFuncBuilder {
826826

827827
let fun_expr = match fun {
828828
ExprFuncKind::Aggregate(mut udaf) => {
829-
udaf.order_by = order_by;
830-
udaf.filter = filter.map(Box::new);
831-
udaf.distinct = distinct;
832-
udaf.null_treatment = null_treatment;
829+
udaf.params.order_by = order_by;
830+
udaf.params.filter = filter.map(Box::new);
831+
udaf.params.distinct = distinct;
832+
udaf.params.null_treatment = null_treatment;
833833
Expr::AggregateFunction(udaf)
834834
}
835835
ExprFuncKind::Window(mut udwf) => {

datafusion/expr/src/expr_schema.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717

1818
use super::{Between, Expr, Like};
1919
use crate::expr::{
20-
AggregateFunction, Alias, BinaryExpr, Cast, InList, InSubquery, Placeholder,
21-
ScalarFunction, TryCast, Unnest, WindowFunction,
20+
AggregateFunction, AggregateFunctionParams, Alias, BinaryExpr, Cast, InList,
21+
InSubquery, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction,
2222
};
2323
use crate::type_coercion::functions::{
2424
data_types_with_aggregate_udf, data_types_with_scalar_udf, data_types_with_window_udf,
@@ -153,7 +153,10 @@ impl ExprSchemable for Expr {
153153
Expr::WindowFunction(window_function) => self
154154
.data_type_and_nullable_with_window_function(schema, window_function)
155155
.map(|(return_type, _)| return_type),
156-
Expr::AggregateFunction(AggregateFunction { func, args, .. }) => {
156+
Expr::AggregateFunction(AggregateFunction {
157+
func,
158+
params: AggregateFunctionParams { args, .. },
159+
}) => {
157160
let data_types = args
158161
.iter()
159162
.map(|e| e.get_type(schema))

datafusion/expr/src/tree_node.rs

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@
1818
//! Tree node implementation for Logical Expressions
1919
2020
use crate::expr::{
21-
AggregateFunction, Alias, Between, BinaryExpr, Case, Cast, GroupingSet, InList,
22-
InSubquery, Like, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction,
21+
AggregateFunction, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Cast,
22+
GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, TryCast, Unnest,
23+
WindowFunction,
2324
};
2425
use crate::{Expr, ExprFunctionExt};
2526

@@ -87,7 +88,7 @@ impl TreeNode for Expr {
8788
}) => (expr, low, high).apply_ref_elements(f),
8889
Expr::Case(Case { expr, when_then_expr, else_expr }) =>
8990
(expr, when_then_expr, else_expr).apply_ref_elements(f),
90-
Expr::AggregateFunction(AggregateFunction { args, filter, order_by, .. }) =>
91+
Expr::AggregateFunction(AggregateFunction { params: AggregateFunctionParams { args, filter, order_by, ..}, .. }) =>
9192
(args, filter, order_by).apply_ref_elements(f),
9293
Expr::WindowFunction(WindowFunction {
9394
args,
@@ -241,12 +242,15 @@ impl TreeNode for Expr {
241242
},
242243
),
243244
Expr::AggregateFunction(AggregateFunction {
244-
args,
245245
func,
246-
distinct,
247-
filter,
248-
order_by,
249-
null_treatment,
246+
params:
247+
AggregateFunctionParams {
248+
args,
249+
distinct,
250+
filter,
251+
order_by,
252+
null_treatment,
253+
},
250254
}) => (args, filter, order_by).map_elements(f)?.map_data(
251255
|(new_args, new_filter, new_order_by)| {
252256
Ok(Expr::AggregateFunction(AggregateFunction::new_udf(

0 commit comments

Comments
 (0)