Skip to content

Commit 4bd5a20

Browse files
committed
Improve unit test for aggregate udf with metadata
1 parent 1f9da9f commit 4bd5a20

File tree

1 file changed

+65
-18
lines changed

1 file changed

+65
-18
lines changed

datafusion/core/tests/user_defined/user_defined_aggregates.rs

Lines changed: 65 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -28,8 +28,8 @@ use std::sync::{
2828
};
2929

3030
use arrow::array::{
31-
types::UInt64Type, Array, AsArray, Int32Array, PrimitiveArray, StringArray,
32-
StructArray, UInt64Array,
31+
record_batch, types::UInt64Type, Array, AsArray, Int32Array, PrimitiveArray,
32+
StringArray, StructArray, UInt64Array,
3333
};
3434
use arrow::datatypes::{Fields, Schema};
3535
use datafusion::common::test_util::batches_to_string;
@@ -998,28 +998,75 @@ impl Accumulator for MetadataBasedAccumulator {
998998

999999
#[tokio::test]
10001000
async fn test_metadata_based_accumulator() -> Result<()> {
1001+
let data_array = Arc::new(UInt64Array::from(vec![0, 5, 10, 15, 20])) as ArrayRef;
1002+
let schema = Arc::new(Schema::new(vec![
1003+
Field::new("no_metadata", DataType::UInt64, true),
1004+
Field::new("with_metadata", DataType::UInt64, true).with_metadata(
1005+
[("modify_values".to_string(), "double_output".to_string())]
1006+
.into_iter()
1007+
.collect(),
1008+
),
1009+
]));
1010+
1011+
let batch = RecordBatch::try_new(
1012+
schema,
1013+
vec![Arc::clone(&data_array), Arc::clone(&data_array)],
1014+
)?;
1015+
10011016
let ctx = SessionContext::new();
1002-
let arr = UInt64Array::from(vec![1]);
1003-
let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(arr) as _)])?;
1004-
ctx.register_batch("t", batch).unwrap();
1017+
ctx.register_batch("t", batch)?;
1018+
let df = ctx.table("t").await?;
10051019

1006-
let udaf_no_metadata =
1020+
let no_output_meta_udf =
10071021
AggregateUDF::from(MetadataBasedAggregateUdf::new(HashMap::new()));
1008-
let udaf_with_metadata = AggregateUDF::from(MetadataBasedAggregateUdf::new(
1009-
HashMap::from_iter([("modify_values".to_string(), "double_output".to_string())]),
1022+
let with_output_meta_udf = AggregateUDF::from(MetadataBasedAggregateUdf::new(
1023+
[("output_metatype".to_string(), "custom_value".to_string())]
1024+
.into_iter()
1025+
.collect(),
10101026
));
1011-
ctx.register_udaf(udaf_no_metadata);
1012-
ctx.register_udaf(udaf_with_metadata);
10131027

1014-
let sql_df = ctx
1015-
.sql("SELECT metadata_based_udf_0(a) FROM t group by a")
1016-
.await?;
1017-
sql_df.show().await?;
1028+
let df = df.aggregate(
1029+
vec![],
1030+
vec![
1031+
no_output_meta_udf
1032+
.call(vec![col("no_metadata")])
1033+
.alias("meta_no_in_no_out"),
1034+
no_output_meta_udf
1035+
.call(vec![col("with_metadata")])
1036+
.alias("meta_with_in_no_out"),
1037+
with_output_meta_udf
1038+
.call(vec![col("no_metadata")])
1039+
.alias("meta_no_in_with_out"),
1040+
with_output_meta_udf
1041+
.call(vec![col("with_metadata")])
1042+
.alias("meta_with_in_with_out"),
1043+
],
1044+
)?;
10181045

1019-
let sql_df = ctx
1020-
.sql("SELECT metadata_based_udf_1(a) FROM t group by a")
1021-
.await?;
1022-
sql_df.show().await?;
1046+
let actual = df.collect().await?;
1047+
1048+
// To test output metadata handling, we attach the expected metadata to the fields of the expected result schema.
1049+
// To test input metadata handling, we check the aggregated values: the sum is doubled when the input field carries the "modify_values" metadata.
1050+
let mut output_meta = HashMap::new();
1051+
let _ = output_meta.insert("output_metatype".to_string(), "custom_value".to_string());
1052+
let expected_schema = Schema::new(vec![
1053+
Field::new("meta_no_in_no_out", DataType::UInt64, true),
1054+
Field::new("meta_with_in_no_out", DataType::UInt64, true),
1055+
Field::new("meta_no_in_with_out", DataType::UInt64, true)
1056+
.with_metadata(output_meta.clone()),
1057+
Field::new("meta_with_in_with_out", DataType::UInt64, true)
1058+
.with_metadata(output_meta.clone()),
1059+
]);
1060+
1061+
let expected = record_batch!(
1062+
("meta_no_in_no_out", UInt64, [50]),
1063+
("meta_with_in_no_out", UInt64, [100]),
1064+
("meta_no_in_with_out", UInt64, [50]),
1065+
("meta_with_in_with_out", UInt64, [100])
1066+
)?
1067+
.with_schema(Arc::new(expected_schema))?;
1068+
1069+
assert_eq!(expected, actual[0]);
10231070

10241071
Ok(())
10251072
}

0 commit comments

Comments
 (0)