//! This module contains end to end demonstrations of creating
//! user defined aggregate functions

+use std::any::Any;
+use std::collections::HashMap;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::mem::{size_of, size_of_val};
use std::sync::{
@@ -26,10 +28,10 @@ use std::sync::{
};

use arrow::array::{
-    types::UInt64Type, AsArray, Int32Array, PrimitiveArray, StringArray, StructArray,
+    record_batch, types::UInt64Type, Array, AsArray, Int32Array, PrimitiveArray,
+    StringArray, StructArray, UInt64Array,
};
use arrow::datatypes::{Fields, Schema};
-
use datafusion::common::test_util::batches_to_string;
use datafusion::dataframe::DataFrame;
use datafusion::datasource::MemTable;
@@ -48,11 +50,12 @@ use datafusion::{
    prelude::SessionContext,
    scalar::ScalarValue,
};
-use datafusion_common::assert_contains;
+use datafusion_common::{assert_contains, exec_datafusion_err};
use datafusion_common::{cast::as_primitive_array, exec_err};
+use datafusion_expr::expr::WindowFunction;
use datafusion_expr::{
-    col, create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator,
-    LogicalPlanBuilder, SimpleAggregateUDF,
+    col, create_udaf, function::AccumulatorArgs, AggregateUDFImpl, Expr,
+    GroupsAccumulator, LogicalPlanBuilder, SimpleAggregateUDF, WindowFunctionDefinition,
};
use datafusion_functions_aggregate::average::AvgAccumulator;

@@ -781,7 +784,7 @@ struct TestGroupsAccumulator {
}

impl AggregateUDFImpl for TestGroupsAccumulator {
-    fn as_any(&self) -> &dyn std::any::Any {
+    fn as_any(&self) -> &dyn Any {
        self
    }

@@ -890,3 +893,263 @@ impl GroupsAccumulator for TestGroupsAccumulator {
        size_of::<u64>()
    }
}
+
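+/// Aggregate UDF used to demonstrate metadata handling: it reads metadata from
+/// its input field to decide how to accumulate, and attaches its own metadata
+/// to the output field via `return_field`.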
+#[derive(Debug)]
+struct MetadataBasedAggregateUdf {
+    name: String,
+    signature: Signature,
+    metadata: HashMap<String, String>,
+}
+
+impl MetadataBasedAggregateUdf {
+    fn new(metadata: HashMap<String, String>) -> Self {
+        // The name we return must be unique. Otherwise we will not call distinct
+        // instances of this UDF. This is a small hack for the unit tests to get unique
+        // names, but you could do something more elegant with the metadata.
+        let name = format!("metadata_based_udf_{}", metadata.len());
+        Self {
+            name,
+            signature: Signature::exact(vec![DataType::UInt64], Volatility::Immutable),
+            metadata,
+        }
+    }
+}
+
+impl AggregateUDFImpl for MetadataBasedAggregateUdf {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        unimplemented!("this should never be called since return_field is implemented");
+    }
+
+    fn return_field(&self, _arg_fields: &[Field]) -> Result<Field> {
+        Ok(Field::new(self.name(), DataType::UInt64, true)
+            .with_metadata(self.metadata.clone()))
+    }
+
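+    // The accumulator inspects the metadata of the input field (resolved against
+    // the schema in `AccumulatorArgs`) to decide whether to double the evaluated sum.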
+    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
+        let input_expr = acc_args
+            .exprs
+            .first()
+            .ok_or(exec_datafusion_err!("Expected one argument"))?;
+        let input_field = input_expr.return_field(acc_args.schema)?;
+
+        let double_output = input_field
+            .metadata()
+            .get("modify_values")
+            .map(|v| v == "double_output")
+            .unwrap_or(false);
+
+        Ok(Box::new(MetadataBasedAccumulator {
+            double_output,
+            curr_sum: 0,
+        }))
+    }
+}
+
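+/// Accumulator that sums `UInt64` values and, when `double_output` is set from the
+/// input field metadata, doubles the final result.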
+#[derive(Debug)]
+struct MetadataBasedAccumulator {
+    double_output: bool,
+    curr_sum: u64,
+}
+
+impl Accumulator for MetadataBasedAccumulator {
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let arr = values[0]
+            .as_any()
+            .downcast_ref::<UInt64Array>()
+            .ok_or(exec_datafusion_err!("Expected UInt64Array"))?;
+
+        self.curr_sum = arr.iter().fold(self.curr_sum, |a, b| a + b.unwrap_or(0));
+
+        Ok(())
+    }
+
+    fn evaluate(&mut self) -> Result<ScalarValue> {
+        let v = match self.double_output {
+            true => self.curr_sum * 2,
+            false => self.curr_sum,
+        };
+
+        Ok(ScalarValue::from(v))
+    }
+
+    fn size(&self) -> usize {
+        9
+    }
+
+    fn state(&mut self) -> Result<Vec<ScalarValue>> {
+        Ok(vec![ScalarValue::from(self.curr_sum)])
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        self.update_batch(states)
+    }
+}
+
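+// Verifies that an aggregate UDF can both read metadata from its input field
+// (doubling the sum for the column tagged `modify_values=double_output`) and
+// attach metadata to its output field via `return_field`.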
+#[tokio::test]
+async fn test_metadata_based_aggregate() -> Result<()> {
+    let data_array = Arc::new(UInt64Array::from(vec![0, 5, 10, 15, 20])) as ArrayRef;
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("no_metadata", DataType::UInt64, true),
+        Field::new("with_metadata", DataType::UInt64, true).with_metadata(
+            [("modify_values".to_string(), "double_output".to_string())]
+                .into_iter()
+                .collect(),
+        ),
+    ]));
+
+    let batch = RecordBatch::try_new(
+        schema,
+        vec![Arc::clone(&data_array), Arc::clone(&data_array)],
+    )?;
+
+    let ctx = SessionContext::new();
+    ctx.register_batch("t", batch)?;
+    let df = ctx.table("t").await?;
+
+    let no_output_meta_udf =
+        AggregateUDF::from(MetadataBasedAggregateUdf::new(HashMap::new()));
+    let with_output_meta_udf = AggregateUDF::from(MetadataBasedAggregateUdf::new(
+        [("output_metatype".to_string(), "custom_value".to_string())]
+            .into_iter()
+            .collect(),
+    ));
+
+    let df = df.aggregate(
+        vec![],
+        vec![
+            no_output_meta_udf
+                .call(vec![col("no_metadata")])
+                .alias("meta_no_in_no_out"),
+            no_output_meta_udf
+                .call(vec![col("with_metadata")])
+                .alias("meta_with_in_no_out"),
+            with_output_meta_udf
+                .call(vec![col("no_metadata")])
+                .alias("meta_no_in_with_out"),
+            with_output_meta_udf
+                .call(vec![col("with_metadata")])
+                .alias("meta_with_in_with_out"),
+        ],
+    )?;
+
+    let actual = df.collect().await?;
+
+    // To test for output metadata handling, we set the expected values on the result
+    // To test for input metadata handling, we check the numbers returned
+    let mut output_meta = HashMap::new();
+    let _ =
+        output_meta.insert("output_metatype".to_string(), "custom_value".to_string());
+    let expected_schema = Schema::new(vec![
+        Field::new("meta_no_in_no_out", DataType::UInt64, true),
+        Field::new("meta_with_in_no_out", DataType::UInt64, true),
+        Field::new("meta_no_in_with_out", DataType::UInt64, true)
+            .with_metadata(output_meta.clone()),
+        Field::new("meta_with_in_with_out", DataType::UInt64, true)
+            .with_metadata(output_meta.clone()),
+    ]);
+
+    let expected = record_batch!(
+        ("meta_no_in_no_out", UInt64, [50]),
+        ("meta_with_in_no_out", UInt64, [100]),
+        ("meta_no_in_with_out", UInt64, [50]),
+        ("meta_with_in_with_out", UInt64, [100])
+    )?
+    .with_schema(Arc::new(expected_schema))?;
+
+    assert_eq!(expected, actual[0]);
+
+    Ok(())
+}
+
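+// Same metadata checks as above, but invoking the aggregate UDF as a window
+// function, so one result row is produced per input row.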
+#[tokio::test]
+async fn test_metadata_based_aggregate_as_window() -> Result<()> {
+    let data_array = Arc::new(UInt64Array::from(vec![0, 5, 10, 15, 20])) as ArrayRef;
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("no_metadata", DataType::UInt64, true),
+        Field::new("with_metadata", DataType::UInt64, true).with_metadata(
+            [("modify_values".to_string(), "double_output".to_string())]
+                .into_iter()
+                .collect(),
+        ),
+    ]));
+
+    let batch = RecordBatch::try_new(
+        schema,
+        vec![Arc::clone(&data_array), Arc::clone(&data_array)],
+    )?;
+
+    let ctx = SessionContext::new();
+    ctx.register_batch("t", batch)?;
+    let df = ctx.table("t").await?;
+
+    let no_output_meta_udf = Arc::new(AggregateUDF::from(
+        MetadataBasedAggregateUdf::new(HashMap::new()),
+    ));
+    let with_output_meta_udf =
+        Arc::new(AggregateUDF::from(MetadataBasedAggregateUdf::new(
+            [("output_metatype".to_string(), "custom_value".to_string())]
+                .into_iter()
+                .collect(),
+        )));
+
+    let df = df.select(vec![
+        Expr::WindowFunction(WindowFunction::new(
+            WindowFunctionDefinition::AggregateUDF(Arc::clone(&no_output_meta_udf)),
+            vec![col("no_metadata")],
+        ))
+        .alias("meta_no_in_no_out"),
+        Expr::WindowFunction(WindowFunction::new(
+            WindowFunctionDefinition::AggregateUDF(no_output_meta_udf),
+            vec![col("with_metadata")],
+        ))
+        .alias("meta_with_in_no_out"),
+        Expr::WindowFunction(WindowFunction::new(
+            WindowFunctionDefinition::AggregateUDF(Arc::clone(&with_output_meta_udf)),
+            vec![col("no_metadata")],
+        ))
+        .alias("meta_no_in_with_out"),
+        Expr::WindowFunction(WindowFunction::new(
+            WindowFunctionDefinition::AggregateUDF(with_output_meta_udf),
+            vec![col("with_metadata")],
+        ))
+        .alias("meta_with_in_with_out"),
+    ])?;
+
+    let actual = df.collect().await?;
+
+    // To test for output metadata handling, we set the expected values on the result
+    // To test for input metadata handling, we check the numbers returned
+    let mut output_meta = HashMap::new();
+    let _ =
+        output_meta.insert("output_metatype".to_string(), "custom_value".to_string());
+    let expected_schema = Schema::new(vec![
+        Field::new("meta_no_in_no_out", DataType::UInt64, true),
+        Field::new("meta_with_in_no_out", DataType::UInt64, true),
+        Field::new("meta_no_in_with_out", DataType::UInt64, true)
+            .with_metadata(output_meta.clone()),
+        Field::new("meta_with_in_with_out", DataType::UInt64, true)
+            .with_metadata(output_meta.clone()),
+    ]);
+
+    let expected = record_batch!(
+        ("meta_no_in_no_out", UInt64, [50, 50, 50, 50, 50]),
+        ("meta_with_in_no_out", UInt64, [100, 100, 100, 100, 100]),
+        ("meta_no_in_with_out", UInt64, [50, 50, 50, 50, 50]),
+        ("meta_with_in_with_out", UInt64, [100, 100, 100, 100, 100])
+    )?
+    .with_schema(Arc::new(expected_schema))?;
+
+    assert_eq!(expected, actual[0]);
+
+    Ok(())
+}