@@ -28,8 +28,8 @@ use std::sync::{
28
28
} ;
29
29
30
30
use arrow:: array:: {
31
- types:: UInt64Type , Array , AsArray , Int32Array , PrimitiveArray , StringArray ,
32
- StructArray , UInt64Array ,
31
+ record_batch , types:: UInt64Type , Array , AsArray , Int32Array , PrimitiveArray ,
32
+ StringArray , StructArray , UInt64Array ,
33
33
} ;
34
34
use arrow:: datatypes:: { Fields , Schema } ;
35
35
use datafusion:: common:: test_util:: batches_to_string;
@@ -998,28 +998,75 @@ impl Accumulator for MetadataBasedAccumulator {
998
998
999
999
#[ tokio:: test]
1000
1000
async fn test_metadata_based_accumulator ( ) -> Result < ( ) > {
1001
+ let data_array = Arc :: new ( UInt64Array :: from ( vec ! [ 0 , 5 , 10 , 15 , 20 ] ) ) as ArrayRef ;
1002
+ let schema = Arc :: new ( Schema :: new ( vec ! [
1003
+ Field :: new( "no_metadata" , DataType :: UInt64 , true ) ,
1004
+ Field :: new( "with_metadata" , DataType :: UInt64 , true ) . with_metadata(
1005
+ [ ( "modify_values" . to_string( ) , "double_output" . to_string( ) ) ]
1006
+ . into_iter( )
1007
+ . collect( ) ,
1008
+ ) ,
1009
+ ] ) ) ;
1010
+
1011
+ let batch = RecordBatch :: try_new (
1012
+ schema,
1013
+ vec ! [ Arc :: clone( & data_array) , Arc :: clone( & data_array) ] ,
1014
+ ) ?;
1015
+
1001
1016
let ctx = SessionContext :: new ( ) ;
1002
- let arr = UInt64Array :: from ( vec ! [ 1 ] ) ;
1003
- let batch = RecordBatch :: try_from_iter ( vec ! [ ( "a" , Arc :: new( arr) as _) ] ) ?;
1004
- ctx. register_batch ( "t" , batch) . unwrap ( ) ;
1017
+ ctx. register_batch ( "t" , batch) ?;
1018
+ let df = ctx. table ( "t" ) . await ?;
1005
1019
1006
- let udaf_no_metadata =
1020
+ let no_output_meta_udf =
1007
1021
AggregateUDF :: from ( MetadataBasedAggregateUdf :: new ( HashMap :: new ( ) ) ) ;
1008
- let udaf_with_metadata = AggregateUDF :: from ( MetadataBasedAggregateUdf :: new (
1009
- HashMap :: from_iter ( [ ( "modify_values" . to_string ( ) , "double_output" . to_string ( ) ) ] ) ,
1022
+ let with_output_meta_udf = AggregateUDF :: from ( MetadataBasedAggregateUdf :: new (
1023
+ [ ( "output_metatype" . to_string ( ) , "custom_value" . to_string ( ) ) ]
1024
+ . into_iter ( )
1025
+ . collect ( ) ,
1010
1026
) ) ;
1011
- ctx. register_udaf ( udaf_no_metadata) ;
1012
- ctx. register_udaf ( udaf_with_metadata) ;
1013
1027
1014
- let sql_df = ctx
1015
- . sql ( "SELECT metadata_based_udf_0(a) FROM t group by a" )
1016
- . await ?;
1017
- sql_df. show ( ) . await ?;
1028
+ let df = df. aggregate (
1029
+ vec ! [ ] ,
1030
+ vec ! [
1031
+ no_output_meta_udf
1032
+ . call( vec![ col( "no_metadata" ) ] )
1033
+ . alias( "meta_no_in_no_out" ) ,
1034
+ no_output_meta_udf
1035
+ . call( vec![ col( "with_metadata" ) ] )
1036
+ . alias( "meta_with_in_no_out" ) ,
1037
+ with_output_meta_udf
1038
+ . call( vec![ col( "no_metadata" ) ] )
1039
+ . alias( "meta_no_in_with_out" ) ,
1040
+ with_output_meta_udf
1041
+ . call( vec![ col( "with_metadata" ) ] )
1042
+ . alias( "meta_with_in_with_out" ) ,
1043
+ ] ,
1044
+ ) ?;
1018
1045
1019
- let sql_df = ctx
1020
- . sql ( "SELECT metadata_based_udf_1(a) FROM t group by a" )
1021
- . await ?;
1022
- sql_df. show ( ) . await ?;
1046
+ let actual = df. collect ( ) . await ?;
1047
+
1048
+ // To test for output metadata handling, we set the expected values on the result
1049
+ // To test for input metadata handling, we check the numbers returned
1050
+ let mut output_meta = HashMap :: new ( ) ;
1051
+ let _ = output_meta. insert ( "output_metatype" . to_string ( ) , "custom_value" . to_string ( ) ) ;
1052
+ let expected_schema = Schema :: new ( vec ! [
1053
+ Field :: new( "meta_no_in_no_out" , DataType :: UInt64 , true ) ,
1054
+ Field :: new( "meta_with_in_no_out" , DataType :: UInt64 , true ) ,
1055
+ Field :: new( "meta_no_in_with_out" , DataType :: UInt64 , true )
1056
+ . with_metadata( output_meta. clone( ) ) ,
1057
+ Field :: new( "meta_with_in_with_out" , DataType :: UInt64 , true )
1058
+ . with_metadata( output_meta. clone( ) ) ,
1059
+ ] ) ;
1060
+
1061
+ let expected = record_batch ! (
1062
+ ( "meta_no_in_no_out" , UInt64 , [ 50 ] ) ,
1063
+ ( "meta_with_in_no_out" , UInt64 , [ 100 ] ) ,
1064
+ ( "meta_no_in_with_out" , UInt64 , [ 50 ] ) ,
1065
+ ( "meta_with_in_with_out" , UInt64 , [ 100 ] )
1066
+ ) ?
1067
+ . with_schema ( Arc :: new ( expected_schema) ) ?;
1068
+
1069
+ assert_eq ! ( expected, actual[ 0 ] ) ;
1023
1070
1024
1071
Ok ( ( ) )
1025
1072
}
0 commit comments