Skip to content

Commit 389f7f7

Browse files
ngli-meM
and
M
authored
Add PartialOrd for the DF subfields/structs for the WindowFunction expr (#12421)
* Added PartialOrd implementations for AggregateUDF, AggregateUDFImpl, BuiltInWindowFunction and WindowUDF. * Added tests for PartialOrd in udwf.rs. * Removed manual implementation of PartialOrd for TypeSignature, replaced with derives. * Adjusted the assertion for clarity on comparing enum variants. * Edited assertions to use partial_cmp for clarity, and reformatted with rustfmt. --------- Co-authored-by: M <[email protected]>
1 parent 6bf3479 commit 389f7f7

File tree

5 files changed

+264
-9
lines changed

5 files changed

+264
-9
lines changed

datafusion/expr-common/src/signature.rs

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ pub enum Volatility {
8484
/// DataType::Timestamp(TimeUnit::Nanosecond, Some(TIMEZONE_WILDCARD.into())),
8585
/// ]);
8686
/// ```
87-
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
87+
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
8888
pub enum TypeSignature {
8989
/// One or more arguments of an common type out of a list of valid types.
9090
///
@@ -127,7 +127,7 @@ pub enum TypeSignature {
127127
Numeric(usize),
128128
}
129129

130-
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
130+
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
131131
pub enum ArrayFunctionSignature {
132132
/// Specialized Signature for ArrayAppend and similar functions
133133
/// The first argument should be List/LargeList/FixedSizedList, and the second argument should be non-list or list.
@@ -241,7 +241,7 @@ impl TypeSignature {
241241
///
242242
/// DataFusion will automatically coerce (cast) argument types to one of the supported
243243
/// function signatures, if possible.
244-
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
244+
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
245245
pub struct Signature {
246246
/// The data types that the function accepts. See [TypeSignature] for more information.
247247
pub type_signature: TypeSignature,
@@ -418,4 +418,24 @@ mod tests {
418418
);
419419
}
420420
}
421+
422+
#[test]
423+
fn type_signature_partial_ord() {
424+
// Test validates that partial ord is defined for TypeSignature and Signature.
425+
assert!(TypeSignature::UserDefined < TypeSignature::VariadicAny);
426+
assert!(TypeSignature::UserDefined < TypeSignature::Any(1));
427+
428+
assert!(
429+
TypeSignature::Uniform(1, vec![DataType::Null])
430+
< TypeSignature::Uniform(1, vec![DataType::Boolean])
431+
);
432+
assert!(
433+
TypeSignature::Uniform(1, vec![DataType::Null])
434+
< TypeSignature::Uniform(2, vec![DataType::Null])
435+
);
436+
assert!(
437+
TypeSignature::Uniform(usize::MAX, vec![DataType::Null])
438+
< TypeSignature::Exact(vec![DataType::Null])
439+
);
440+
}
421441
}

datafusion/expr/src/built_in_window_function.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ impl fmt::Display for BuiltInWindowFunction {
3838
/// A [window function] built in to DataFusion
3939
///
4040
/// [window function]: https://en.wikipedia.org/wiki/Window_function_(SQL)
41-
#[derive(Debug, Clone, PartialEq, Eq, Hash, EnumIter)]
41+
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, EnumIter)]
4242
pub enum BuiltInWindowFunction {
4343
/// rank of the current row with gaps; same as row_number of its first peer
4444
Rank,

datafusion/expr/src/expr.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -688,7 +688,7 @@ impl AggregateFunction {
688688
}
689689

690690
/// WindowFunction
691-
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
691+
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
692692
/// Defines which implementation of an aggregate function DataFusion should call.
693693
pub enum WindowFunctionDefinition {
694694
/// A built in aggregate function that leverages an aggregate function

datafusion/expr/src/udaf.rs

Lines changed: 128 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
//! [`AggregateUDF`]: User Defined Aggregate Functions
1919
2020
use std::any::Any;
21+
use std::cmp::Ordering;
2122
use std::fmt::{self, Debug, Formatter};
2223
use std::hash::{DefaultHasher, Hash, Hasher};
2324
use std::sync::Arc;
@@ -68,7 +69,7 @@ use crate::{AccumulatorFactoryFunction, ReturnTypeFunction, Signature};
6869
/// [`create_udaf`]: crate::expr_fn::create_udaf
6970
/// [`simple_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udaf.rs
7071
/// [`advanced_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs
71-
#[derive(Debug, Clone)]
72+
#[derive(Debug, Clone, PartialOrd)]
7273
pub struct AggregateUDF {
7374
inner: Arc<dyn AggregateUDFImpl>,
7475
}
@@ -584,6 +585,24 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
584585
}
585586
}
586587

588+
impl PartialEq for dyn AggregateUDFImpl {
589+
fn eq(&self, other: &Self) -> bool {
590+
self.equals(other)
591+
}
592+
}
593+
594+
// manual implementation of `PartialOrd`
595+
// There might be some wackiness with it, but this is based on the impl of eq for AggregateUDFImpl
596+
// https://users.rust-lang.org/t/how-to-compare-two-trait-objects-for-equality/88063/5
597+
impl PartialOrd for dyn AggregateUDFImpl {
598+
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
599+
match self.name().partial_cmp(other.name()) {
600+
Some(Ordering::Equal) => self.signature().partial_cmp(other.signature()),
601+
cmp => cmp,
602+
}
603+
}
604+
}
605+
587606
pub enum ReversedUDAF {
588607
/// The expression is the same as the original expression, like SUM, COUNT
589608
Identical,
@@ -758,3 +777,111 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper {
758777
(self.accumulator)(acc_args)
759778
}
760779
}
780+
781+
#[cfg(test)]
782+
mod test {
783+
use crate::{AggregateUDF, AggregateUDFImpl};
784+
use arrow::datatypes::{DataType, Field};
785+
use datafusion_common::Result;
786+
use datafusion_expr_common::accumulator::Accumulator;
787+
use datafusion_expr_common::signature::{Signature, Volatility};
788+
use datafusion_functions_aggregate_common::accumulator::{
789+
AccumulatorArgs, StateFieldsArgs,
790+
};
791+
use std::any::Any;
792+
use std::cmp::Ordering;
793+
794+
#[derive(Debug, Clone)]
795+
struct AMeanUdf {
796+
signature: Signature,
797+
}
798+
799+
impl AMeanUdf {
800+
fn new() -> Self {
801+
Self {
802+
signature: Signature::uniform(
803+
1,
804+
vec![DataType::Float64],
805+
Volatility::Immutable,
806+
),
807+
}
808+
}
809+
}
810+
811+
impl AggregateUDFImpl for AMeanUdf {
812+
fn as_any(&self) -> &dyn Any {
813+
self
814+
}
815+
fn name(&self) -> &str {
816+
"a"
817+
}
818+
fn signature(&self) -> &Signature {
819+
&self.signature
820+
}
821+
fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
822+
unimplemented!()
823+
}
824+
fn accumulator(
825+
&self,
826+
_acc_args: AccumulatorArgs,
827+
) -> Result<Box<dyn Accumulator>> {
828+
unimplemented!()
829+
}
830+
fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
831+
unimplemented!()
832+
}
833+
}
834+
835+
#[derive(Debug, Clone)]
836+
struct BMeanUdf {
837+
signature: Signature,
838+
}
839+
impl BMeanUdf {
840+
fn new() -> Self {
841+
Self {
842+
signature: Signature::uniform(
843+
1,
844+
vec![DataType::Float64],
845+
Volatility::Immutable,
846+
),
847+
}
848+
}
849+
}
850+
851+
impl AggregateUDFImpl for BMeanUdf {
852+
fn as_any(&self) -> &dyn Any {
853+
self
854+
}
855+
fn name(&self) -> &str {
856+
"b"
857+
}
858+
fn signature(&self) -> &Signature {
859+
&self.signature
860+
}
861+
fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
862+
unimplemented!()
863+
}
864+
fn accumulator(
865+
&self,
866+
_acc_args: AccumulatorArgs,
867+
) -> Result<Box<dyn Accumulator>> {
868+
unimplemented!()
869+
}
870+
fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
871+
unimplemented!()
872+
}
873+
}
874+
875+
#[test]
876+
fn test_partial_ord() {
877+
// Test validates that partial ord is defined for AggregateUDF using the name and signature,
878+
// not intended to exhaustively test all possibilities
879+
let a1 = AggregateUDF::from(AMeanUdf::new());
880+
let a2 = AggregateUDF::from(AMeanUdf::new());
881+
assert_eq!(a1.partial_cmp(&a2), Some(Ordering::Equal));
882+
883+
let b1 = AggregateUDF::from(BMeanUdf::new());
884+
assert!(a1 < b1);
885+
assert!(!(a1 == b1));
886+
}
887+
}

datafusion/expr/src/udwf.rs

Lines changed: 111 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,15 @@
1818
//! [`WindowUDF`]: User Defined Window Functions
1919
2020
use arrow::compute::SortOptions;
21+
use arrow::datatypes::DataType;
22+
use std::cmp::Ordering;
2123
use std::hash::{DefaultHasher, Hash, Hasher};
2224
use std::{
2325
any::Any,
2426
fmt::{self, Debug, Display, Formatter},
2527
sync::Arc,
2628
};
2729

28-
use arrow::datatypes::DataType;
29-
3030
use datafusion_common::{not_impl_err, Result};
3131

3232
use crate::expr::WindowFunction;
@@ -54,7 +54,7 @@ use crate::{
5454
/// [`create_udwf`]: crate::expr_fn::create_udwf
5555
/// [`simple_udwf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udwf.rs
5656
/// [`advanced_udwf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udwf.rs
57-
#[derive(Debug, Clone)]
57+
#[derive(Debug, Clone, PartialOrd)]
5858
pub struct WindowUDF {
5959
inner: Arc<dyn WindowUDFImpl>,
6060
}
@@ -386,6 +386,21 @@ pub trait WindowUDFImpl: Debug + Send + Sync {
386386
}
387387
}
388388

389+
impl PartialEq for dyn WindowUDFImpl {
390+
fn eq(&self, other: &Self) -> bool {
391+
self.equals(other)
392+
}
393+
}
394+
395+
impl PartialOrd for dyn WindowUDFImpl {
396+
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
397+
match self.name().partial_cmp(other.name()) {
398+
Some(Ordering::Equal) => self.signature().partial_cmp(other.signature()),
399+
cmp => cmp,
400+
}
401+
}
402+
}
403+
389404
/// WindowUDF that adds an alias to the underlying function. It is better to
390405
/// implement [`WindowUDFImpl`], which supports aliases, directly if possible.
391406
#[derive(Debug)]
@@ -511,3 +526,96 @@ impl WindowUDFImpl for WindowUDFLegacyWrapper {
511526
(self.partition_evaluator_factory)()
512527
}
513528
}
529+
530+
#[cfg(test)]
531+
mod test {
532+
use crate::{PartitionEvaluator, WindowUDF, WindowUDFImpl};
533+
use arrow::datatypes::DataType;
534+
use datafusion_common::Result;
535+
use datafusion_expr_common::signature::{Signature, Volatility};
536+
use std::any::Any;
537+
use std::cmp::Ordering;
538+
539+
#[derive(Debug, Clone)]
540+
struct AWindowUDF {
541+
signature: Signature,
542+
}
543+
544+
impl AWindowUDF {
545+
fn new() -> Self {
546+
Self {
547+
signature: Signature::uniform(
548+
1,
549+
vec![DataType::Int32],
550+
Volatility::Immutable,
551+
),
552+
}
553+
}
554+
}
555+
556+
/// Implement the WindowUDFImpl trait for AddOne
557+
impl WindowUDFImpl for AWindowUDF {
558+
fn as_any(&self) -> &dyn Any {
559+
self
560+
}
561+
fn name(&self) -> &str {
562+
"a"
563+
}
564+
fn signature(&self) -> &Signature {
565+
&self.signature
566+
}
567+
fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
568+
unimplemented!()
569+
}
570+
fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
571+
unimplemented!()
572+
}
573+
}
574+
575+
#[derive(Debug, Clone)]
576+
struct BWindowUDF {
577+
signature: Signature,
578+
}
579+
580+
impl BWindowUDF {
581+
fn new() -> Self {
582+
Self {
583+
signature: Signature::uniform(
584+
1,
585+
vec![DataType::Int32],
586+
Volatility::Immutable,
587+
),
588+
}
589+
}
590+
}
591+
592+
/// Implement the WindowUDFImpl trait for AddOne
593+
impl WindowUDFImpl for BWindowUDF {
594+
fn as_any(&self) -> &dyn Any {
595+
self
596+
}
597+
fn name(&self) -> &str {
598+
"b"
599+
}
600+
fn signature(&self) -> &Signature {
601+
&self.signature
602+
}
603+
fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
604+
unimplemented!()
605+
}
606+
fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
607+
unimplemented!()
608+
}
609+
}
610+
611+
#[test]
612+
fn test_partial_ord() {
613+
let a1 = WindowUDF::from(AWindowUDF::new());
614+
let a2 = WindowUDF::from(AWindowUDF::new());
615+
assert_eq!(a1.partial_cmp(&a2), Some(Ordering::Equal));
616+
617+
let b1 = WindowUDF::from(BWindowUDF::new());
618+
assert!(a1 < b1);
619+
assert!(!(a1 == b1));
620+
}
621+
}

0 commit comments

Comments
 (0)