Skip to content

AggregateUDFImpl::schema_name and AggregateUDFImpl::display_name for customizable name #14695

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Feb 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions datafusion-examples/examples/advanced_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -423,11 +423,11 @@ impl AggregateUDFImpl for SimplifiedGeoMeanUdaf {
// In real-world scenarios, you might create UDFs from built-in expressions.
Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
Arc::new(AggregateUDF::from(GeoMeanUdaf::new())),
aggregate_function.args,
aggregate_function.distinct,
aggregate_function.filter,
aggregate_function.order_by,
aggregate_function.null_treatment,
aggregate_function.params.args,
aggregate_function.params.distinct,
aggregate_function.params.filter,
aggregate_function.params.order_by,
aggregate_function.params.null_treatment,
)))
};
Some(Box::new(simplify))
Expand Down
16 changes: 10 additions & 6 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ use datafusion_common::{
};
use datafusion_expr::dml::{CopyTo, InsertOp};
use datafusion_expr::expr::{
physical_name, AggregateFunction, Alias, GroupingSet, WindowFunction,
physical_name, AggregateFunction, AggregateFunctionParams, Alias, GroupingSet,
WindowFunction,
};
use datafusion_expr::expr_rewriter::unnormalize_cols;
use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary;
Expand Down Expand Up @@ -1579,11 +1580,14 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
match e {
Expr::AggregateFunction(AggregateFunction {
func,
distinct,
args,
filter,
order_by,
null_treatment,
params:
AggregateFunctionParams {
args,
distinct,
filter,
order_by,
null_treatment,
},
}) => {
let name = if let Some(name) = name {
name
Expand Down
14 changes: 8 additions & 6 deletions datafusion/core/tests/execution/logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use arrow::datatypes::{DataType, Field};
use datafusion::execution::session_state::SessionStateBuilder;
use datafusion_common::{Column, DFSchema, Result, ScalarValue, Spans};
use datafusion_execution::TaskContext;
use datafusion_expr::expr::AggregateFunction;
use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
use datafusion_expr::logical_plan::{LogicalPlan, Values};
use datafusion_expr::{Aggregate, AggregateUDF, Expr};
use datafusion_functions_aggregate::count::Count;
Expand Down Expand Up @@ -60,11 +60,13 @@ async fn count_only_nulls() -> Result<()> {
vec![],
vec![Expr::AggregateFunction(AggregateFunction {
func: Arc::new(AggregateUDF::new_from_impl(Count::new())),
args: vec![input_col_ref],
distinct: false,
filter: None,
order_by: None,
null_treatment: None,
params: AggregateFunctionParams {
args: vec![input_col_ref],
distinct: false,
filter: None,
order_by: None,
null_treatment: None,
},
})],
)?);

Expand Down
120 changes: 52 additions & 68 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -696,7 +696,11 @@ impl<'a> TreeNodeContainer<'a, Expr> for Sort {
pub struct AggregateFunction {
/// Name of the function
pub func: Arc<crate::AggregateUDF>,
/// Parameters of the call, grouped in [`AggregateFunctionParams`]: the argument
/// expressions plus the `distinct`, `filter`, `order_by` and `null_treatment` settings
pub params: AggregateFunctionParams,
}

#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
pub struct AggregateFunctionParams {
pub args: Vec<Expr>,
/// Whether this is a DISTINCT aggregation or not
pub distinct: bool,
Expand All @@ -719,11 +723,13 @@ impl AggregateFunction {
) -> Self {
Self {
func,
args,
distinct,
filter,
order_by,
null_treatment,
params: AggregateFunctionParams {
args,
distinct,
filter,
order_by,
null_treatment,
},
}
}
}
Expand Down Expand Up @@ -1864,19 +1870,25 @@ impl NormalizeEq for Expr {
(
Expr::AggregateFunction(AggregateFunction {
func: self_func,
args: self_args,
distinct: self_distinct,
filter: self_filter,
order_by: self_order_by,
null_treatment: self_null_treatment,
params:
AggregateFunctionParams {
args: self_args,
distinct: self_distinct,
filter: self_filter,
order_by: self_order_by,
null_treatment: self_null_treatment,
},
}),
Expr::AggregateFunction(AggregateFunction {
func: other_func,
args: other_args,
distinct: other_distinct,
filter: other_filter,
order_by: other_order_by,
null_treatment: other_null_treatment,
params:
AggregateFunctionParams {
args: other_args,
distinct: other_distinct,
filter: other_filter,
order_by: other_order_by,
null_treatment: other_null_treatment,
},
}),
) => {
self_func.name() == other_func.name()
Expand Down Expand Up @@ -2154,11 +2166,14 @@ impl HashNode for Expr {
}
Expr::AggregateFunction(AggregateFunction {
func,
args: _args,
distinct,
filter: _filter,
order_by: _order_by,
null_treatment,
params:
AggregateFunctionParams {
args: _args,
distinct,
filter: _,
order_by: _,
null_treatment,
},
}) => {
func.hash(state);
distinct.hash(state);
Expand Down Expand Up @@ -2264,35 +2279,15 @@ impl Display for SchemaDisplay<'_> {
| Expr::Placeholder(_)
| Expr::Wildcard { .. } => write!(f, "{}", self.0),

Expr::AggregateFunction(AggregateFunction {
func,
args,
distinct,
filter,
order_by,
null_treatment,
}) => {
write!(
f,
"{}({}{})",
func.name(),
if *distinct { "DISTINCT " } else { "" },
schema_name_from_exprs_comma_separated_without_space(args)?
)?;

if let Some(null_treatment) = null_treatment {
write!(f, " {}", null_treatment)?;
Expr::AggregateFunction(AggregateFunction { func, params }) => {
match func.schema_name(params) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

Ok(name) => {
write!(f, "{name}")
}
Err(e) => {
write!(f, "got error from schema_name {}", e)
}
}

if let Some(filter) = filter {
write!(f, " FILTER (WHERE {filter})")?;
};

if let Some(order_by) = order_by {
write!(f, " ORDER BY [{}]", schema_name_from_sorts(order_by)?)?;
};

Ok(())
}
// Expr is not shown since it is aliased
Expr::Alias(Alias {
Expand Down Expand Up @@ -2653,26 +2648,15 @@ impl Display for Expr {
)?;
Ok(())
}
Expr::AggregateFunction(AggregateFunction {
func,
distinct,
ref args,
filter,
order_by,
null_treatment,
..
}) => {
fmt_function(f, func.name(), *distinct, args, true)?;
if let Some(nt) = null_treatment {
write!(f, " {}", nt)?;
}
if let Some(fe) = filter {
write!(f, " FILTER (WHERE {fe})")?;
}
if let Some(ob) = order_by {
write!(f, " ORDER BY [{}]", expr_vec_fmt!(ob))?;
Expr::AggregateFunction(AggregateFunction { func, params }) => {
match func.display_name(params) {
Ok(name) => {
write!(f, "{}", name)
}
Err(e) => {
write!(f, "got error from display_name {}", e)
}
}
Ok(())
}
Expr::Between(Between {
expr,
Expand Down
8 changes: 4 additions & 4 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -830,10 +830,10 @@ impl ExprFuncBuilder {

let fun_expr = match fun {
ExprFuncKind::Aggregate(mut udaf) => {
udaf.order_by = order_by;
udaf.filter = filter.map(Box::new);
udaf.distinct = distinct;
udaf.null_treatment = null_treatment;
udaf.params.order_by = order_by;
udaf.params.filter = filter.map(Box::new);
udaf.params.distinct = distinct;
udaf.params.null_treatment = null_treatment;
Expr::AggregateFunction(udaf)
}
ExprFuncKind::Window(mut udwf) => {
Expand Down
9 changes: 6 additions & 3 deletions datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

use super::{Between, Expr, Like};
use crate::expr::{
AggregateFunction, Alias, BinaryExpr, Cast, InList, InSubquery, Placeholder,
ScalarFunction, TryCast, Unnest, WindowFunction,
AggregateFunction, AggregateFunctionParams, Alias, BinaryExpr, Cast, InList,
InSubquery, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction,
};
use crate::type_coercion::functions::{
data_types_with_aggregate_udf, data_types_with_scalar_udf, data_types_with_window_udf,
Expand Down Expand Up @@ -153,7 +153,10 @@ impl ExprSchemable for Expr {
Expr::WindowFunction(window_function) => self
.data_type_and_nullable_with_window_function(schema, window_function)
.map(|(return_type, _)| return_type),
Expr::AggregateFunction(AggregateFunction { func, args, .. }) => {
Expr::AggregateFunction(AggregateFunction {
func,
params: AggregateFunctionParams { args, .. },
}) => {
let data_types = args
.iter()
.map(|e| e.get_type(schema))
Expand Down
20 changes: 12 additions & 8 deletions datafusion/expr/src/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
//! Tree node implementation for Logical Expressions

use crate::expr::{
AggregateFunction, Alias, Between, BinaryExpr, Case, Cast, GroupingSet, InList,
InSubquery, Like, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction,
AggregateFunction, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Cast,
GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, TryCast, Unnest,
WindowFunction,
};
use crate::{Expr, ExprFunctionExt};

Expand Down Expand Up @@ -87,7 +88,7 @@ impl TreeNode for Expr {
}) => (expr, low, high).apply_ref_elements(f),
Expr::Case(Case { expr, when_then_expr, else_expr }) =>
(expr, when_then_expr, else_expr).apply_ref_elements(f),
Expr::AggregateFunction(AggregateFunction { args, filter, order_by, .. }) =>
Expr::AggregateFunction(AggregateFunction { params: AggregateFunctionParams { args, filter, order_by, ..}, .. }) =>
(args, filter, order_by).apply_ref_elements(f),
Expr::WindowFunction(WindowFunction {
args,
Expand Down Expand Up @@ -241,12 +242,15 @@ impl TreeNode for Expr {
},
),
Expr::AggregateFunction(AggregateFunction {
args,
func,
distinct,
filter,
order_by,
null_treatment,
params:
AggregateFunctionParams {
args,
distinct,
filter,
order_by,
null_treatment,
},
}) => (args, filter, order_by).map_elements(f)?.map_data(
|(new_args, new_filter, new_order_by)| {
Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
Expand Down
Loading
Loading