Skip to content

Commit 502ce4b

Browse files
Blizzaraalamb
andauthored
fix: UDF, UDAF, UDWF with_alias(..) should wrap the inner function fully (#12098)
* fix: UDF, UDAF, UDWF with_alias(..) should wrap the inner function fully * revert back to having Arc<Self> * add notes about adding stuff into Aliased impls * fix clippy --------- Co-authored-by: Andrew Lamb <[email protected]>
1 parent 9a1a92d commit 502ce4b

File tree

3 files changed

+138
-2
lines changed

3 files changed

+138
-2
lines changed

datafusion/expr/src/udaf.rs

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,9 @@ where
337337
/// let expr = geometric_mean.call(vec![col("a")]);
338338
/// ```
339339
pub trait AggregateUDFImpl: Debug + Send + Sync {
340+
// Note: When adding any methods (with default implementations), remember to add them also
341+
// into the AliasedAggregateUDFImpl below!
342+
340343
/// Returns this object as an [`Any`] trait object
341344
fn as_any(&self) -> &dyn Any;
342345

@@ -635,6 +638,60 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl {
635638
&self.aliases
636639
}
637640

641+
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
642+
self.inner.state_fields(args)
643+
}
644+
645+
fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
646+
self.inner.groups_accumulator_supported(args)
647+
}
648+
649+
fn create_groups_accumulator(
650+
&self,
651+
args: AccumulatorArgs,
652+
) -> Result<Box<dyn GroupsAccumulator>> {
653+
self.inner.create_groups_accumulator(args)
654+
}
655+
656+
fn create_sliding_accumulator(
657+
&self,
658+
args: AccumulatorArgs,
659+
) -> Result<Box<dyn Accumulator>> {
660+
self.inner.accumulator(args)
661+
}
662+
663+
fn with_beneficial_ordering(
664+
self: Arc<Self>,
665+
beneficial_ordering: bool,
666+
) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
667+
Arc::clone(&self.inner)
668+
.with_beneficial_ordering(beneficial_ordering)
669+
.map(|udf| {
670+
udf.map(|udf| {
671+
Arc::new(AliasedAggregateUDFImpl {
672+
inner: udf,
673+
aliases: self.aliases.clone(),
674+
}) as Arc<dyn AggregateUDFImpl>
675+
})
676+
})
677+
}
678+
679+
fn order_sensitivity(&self) -> AggregateOrderSensitivity {
680+
self.inner.order_sensitivity()
681+
}
682+
683+
fn simplify(&self) -> Option<AggregateFunctionSimplification> {
684+
self.inner.simplify()
685+
}
686+
687+
fn reverse_expr(&self) -> ReversedUDAF {
688+
self.inner.reverse_expr()
689+
}
690+
691+
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
692+
self.inner.coerce_types(arg_types)
693+
}
694+
638695
fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
639696
if let Some(other) = other.as_any().downcast_ref::<AliasedAggregateUDFImpl>() {
640697
self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases
@@ -649,6 +706,10 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl {
649706
self.aliases.hash(hasher);
650707
hasher.finish()
651708
}
709+
710+
fn is_descending(&self) -> Option<bool> {
711+
self.inner.is_descending()
712+
}
652713
}
653714

654715
/// Implementation of [`AggregateUDFImpl`] that wraps the function style pointers

datafusion/expr/src/udf.rs

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,9 @@ where
346346
/// let expr = add_one.call(vec![col("a")]);
347347
/// ```
348348
pub trait ScalarUDFImpl: Debug + Send + Sync {
349+
// Note: When adding any methods (with default implementations), remember to add them also
350+
// into the AliasedScalarUDFImpl below!
351+
349352
/// Returns this object as an [`Any`] trait object
350353
fn as_any(&self) -> &dyn Any;
351354

@@ -632,6 +635,14 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl {
632635
self.inner.name()
633636
}
634637

638+
fn display_name(&self, args: &[Expr]) -> Result<String> {
639+
self.inner.display_name(args)
640+
}
641+
642+
fn schema_name(&self, args: &[Expr]) -> Result<String> {
643+
self.inner.schema_name(args)
644+
}
645+
635646
fn signature(&self) -> &Signature {
636647
self.inner.signature()
637648
}
@@ -640,12 +651,57 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl {
640651
self.inner.return_type(arg_types)
641652
}
642653

654+
fn aliases(&self) -> &[String] {
655+
&self.aliases
656+
}
657+
658+
fn return_type_from_exprs(
659+
&self,
660+
args: &[Expr],
661+
schema: &dyn ExprSchema,
662+
arg_types: &[DataType],
663+
) -> Result<DataType> {
664+
self.inner.return_type_from_exprs(args, schema, arg_types)
665+
}
666+
643667
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
644668
self.inner.invoke(args)
645669
}
646670

647-
fn aliases(&self) -> &[String] {
648-
&self.aliases
671+
fn invoke_no_args(&self, number_rows: usize) -> Result<ColumnarValue> {
672+
self.inner.invoke_no_args(number_rows)
673+
}
674+
675+
fn simplify(
676+
&self,
677+
args: Vec<Expr>,
678+
info: &dyn SimplifyInfo,
679+
) -> Result<ExprSimplifyResult> {
680+
self.inner.simplify(args, info)
681+
}
682+
683+
fn short_circuits(&self) -> bool {
684+
self.inner.short_circuits()
685+
}
686+
687+
fn evaluate_bounds(&self, input: &[&Interval]) -> Result<Interval> {
688+
self.inner.evaluate_bounds(input)
689+
}
690+
691+
fn propagate_constraints(
692+
&self,
693+
interval: &Interval,
694+
inputs: &[&Interval],
695+
) -> Result<Option<Vec<Interval>>> {
696+
self.inner.propagate_constraints(interval, inputs)
697+
}
698+
699+
fn output_ordering(&self, inputs: &[ExprProperties]) -> Result<SortProperties> {
700+
self.inner.output_ordering(inputs)
701+
}
702+
703+
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
704+
self.inner.coerce_types(arg_types)
649705
}
650706

651707
fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {

datafusion/expr/src/udwf.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,9 @@ where
266266
/// .unwrap();
267267
/// ```
268268
pub trait WindowUDFImpl: Debug + Send + Sync {
269+
// Note: When adding any methods (with default implementations), remember to add them also
270+
// into the AliasedWindowUDFImpl below!
271+
269272
/// Returns this object as an [`Any`] trait object
270273
fn as_any(&self) -> &dyn Any;
271274

@@ -428,6 +431,10 @@ impl WindowUDFImpl for AliasedWindowUDFImpl {
428431
&self.aliases
429432
}
430433

434+
fn simplify(&self) -> Option<WindowFunctionSimplification> {
435+
self.inner.simplify()
436+
}
437+
431438
fn equals(&self, other: &dyn WindowUDFImpl) -> bool {
432439
if let Some(other) = other.as_any().downcast_ref::<AliasedWindowUDFImpl>() {
433440
self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases
@@ -442,6 +449,18 @@ impl WindowUDFImpl for AliasedWindowUDFImpl {
442449
self.aliases.hash(hasher);
443450
hasher.finish()
444451
}
452+
453+
fn nullable(&self) -> bool {
454+
self.inner.nullable()
455+
}
456+
457+
fn sort_options(&self) -> Option<SortOptions> {
458+
self.inner.sort_options()
459+
}
460+
461+
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
462+
self.inner.coerce_types(arg_types)
463+
}
445464
}
446465

447466
/// Implementation of [`WindowUDFImpl`] that wraps the function style pointers

0 commit comments

Comments
 (0)