Skip to content

Commit 577c424

Browse files
timsaucer and alamb
authored
feat: metadata handling for aggregates and window functions (#15911)
* Move expr_schema to use return_field instead of return_type * More work on moving to Field from DataType for aggregates * Update field output name for aggregates * Improve unit test for aggregate udf with metadata * Move window functions over to use Field instead of DataType * Correct nullability flag * Add import after rebase * Add unit test for using udaf as window function with metadata processing * Update documentation for migration guide * Update naming from data type to field to match the actual parameters passed * Avoid some allocations * Update docs to use aggregate example --------- Co-authored-by: Andrew Lamb <[email protected]>
1 parent ca46932 commit 577c424

File tree

36 files changed

+838
-255
lines changed

36 files changed

+838
-255
lines changed

datafusion-examples/examples/advanced_udaf.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ impl AggregateUDFImpl for GeoMeanUdaf {
9494
/// This is the description of the state. accumulator's state() must match the types here.
9595
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
9696
Ok(vec![
97-
Field::new("prod", args.return_type.clone(), true),
97+
Field::new("prod", args.return_field.data_type().clone(), true),
9898
Field::new("n", DataType::UInt32, true),
9999
])
100100
}

datafusion/core/tests/fuzz_cases/window_fuzz.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ use datafusion::prelude::{SessionConfig, SessionContext};
3535
use datafusion_common::HashMap;
3636
use datafusion_common::{Result, ScalarValue};
3737
use datafusion_common_runtime::SpawnedTask;
38-
use datafusion_expr::type_coercion::functions::data_types_with_aggregate_udf;
38+
use datafusion_expr::type_coercion::functions::fields_with_aggregate_udf;
3939
use datafusion_expr::{
4040
WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
4141
};
@@ -448,9 +448,9 @@ fn get_random_function(
448448
if !args.is_empty() {
449449
// Do type coercion first argument
450450
let a = args[0].clone();
451-
let dt = a.data_type(schema.as_ref()).unwrap();
452-
let coerced = data_types_with_aggregate_udf(&[dt], udf).unwrap();
453-
args[0] = cast(a, schema, coerced[0].clone()).unwrap();
451+
let dt = a.return_field(schema.as_ref()).unwrap();
452+
let coerced = fields_with_aggregate_udf(&[dt], udf).unwrap();
453+
args[0] = cast(a, schema, coerced[0].data_type().clone()).unwrap();
454454
}
455455
}
456456

datafusion/core/tests/user_defined/user_defined_aggregates.rs

Lines changed: 269 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
//! This module contains end to end demonstrations of creating
1919
//! user defined aggregate functions
2020
21+
use std::any::Any;
22+
use std::collections::HashMap;
2123
use std::hash::{DefaultHasher, Hash, Hasher};
2224
use std::mem::{size_of, size_of_val};
2325
use std::sync::{
@@ -26,10 +28,10 @@ use std::sync::{
2628
};
2729

2830
use arrow::array::{
29-
types::UInt64Type, AsArray, Int32Array, PrimitiveArray, StringArray, StructArray,
31+
record_batch, types::UInt64Type, Array, AsArray, Int32Array, PrimitiveArray,
32+
StringArray, StructArray, UInt64Array,
3033
};
3134
use arrow::datatypes::{Fields, Schema};
32-
3335
use datafusion::common::test_util::batches_to_string;
3436
use datafusion::dataframe::DataFrame;
3537
use datafusion::datasource::MemTable;
@@ -48,11 +50,12 @@ use datafusion::{
4850
prelude::SessionContext,
4951
scalar::ScalarValue,
5052
};
51-
use datafusion_common::assert_contains;
53+
use datafusion_common::{assert_contains, exec_datafusion_err};
5254
use datafusion_common::{cast::as_primitive_array, exec_err};
55+
use datafusion_expr::expr::WindowFunction;
5356
use datafusion_expr::{
54-
col, create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator,
55-
LogicalPlanBuilder, SimpleAggregateUDF,
57+
col, create_udaf, function::AccumulatorArgs, AggregateUDFImpl, Expr,
58+
GroupsAccumulator, LogicalPlanBuilder, SimpleAggregateUDF, WindowFunctionDefinition,
5659
};
5760
use datafusion_functions_aggregate::average::AvgAccumulator;
5861

@@ -781,7 +784,7 @@ struct TestGroupsAccumulator {
781784
}
782785

783786
impl AggregateUDFImpl for TestGroupsAccumulator {
784-
fn as_any(&self) -> &dyn std::any::Any {
787+
fn as_any(&self) -> &dyn Any {
785788
self
786789
}
787790

@@ -890,3 +893,263 @@ impl GroupsAccumulator for TestGroupsAccumulator {
890893
size_of::<u64>()
891894
}
892895
}
896+
897+
#[derive(Debug)]
898+
struct MetadataBasedAggregateUdf {
899+
name: String,
900+
signature: Signature,
901+
metadata: HashMap<String, String>,
902+
}
903+
904+
impl MetadataBasedAggregateUdf {
905+
fn new(metadata: HashMap<String, String>) -> Self {
906+
// The name we return must be unique. Otherwise we will not call distinct
907+
// instances of this UDF. This is a small hack for the unit tests to get unique
908+
// names, but you could do something more elegant with the metadata.
909+
let name = format!("metadata_based_udf_{}", metadata.len());
910+
Self {
911+
name,
912+
signature: Signature::exact(vec![DataType::UInt64], Volatility::Immutable),
913+
metadata,
914+
}
915+
}
916+
}
917+
918+
impl AggregateUDFImpl for MetadataBasedAggregateUdf {
919+
fn as_any(&self) -> &dyn Any {
920+
self
921+
}
922+
923+
fn name(&self) -> &str {
924+
&self.name
925+
}
926+
927+
fn signature(&self) -> &Signature {
928+
&self.signature
929+
}
930+
931+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
932+
unimplemented!("this should never be called since return_field is implemented");
933+
}
934+
935+
fn return_field(&self, _arg_fields: &[Field]) -> Result<Field> {
936+
Ok(Field::new(self.name(), DataType::UInt64, true)
937+
.with_metadata(self.metadata.clone()))
938+
}
939+
940+
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
941+
let input_expr = acc_args
942+
.exprs
943+
.first()
944+
.ok_or(exec_datafusion_err!("Expected one argument"))?;
945+
let input_field = input_expr.return_field(acc_args.schema)?;
946+
947+
let double_output = input_field
948+
.metadata()
949+
.get("modify_values")
950+
.map(|v| v == "double_output")
951+
.unwrap_or(false);
952+
953+
Ok(Box::new(MetadataBasedAccumulator {
954+
double_output,
955+
curr_sum: 0,
956+
}))
957+
}
958+
}
959+
960+
#[derive(Debug)]
961+
struct MetadataBasedAccumulator {
962+
double_output: bool,
963+
curr_sum: u64,
964+
}
965+
966+
impl Accumulator for MetadataBasedAccumulator {
967+
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
968+
let arr = values[0]
969+
.as_any()
970+
.downcast_ref::<UInt64Array>()
971+
.ok_or(exec_datafusion_err!("Expected UInt64Array"))?;
972+
973+
self.curr_sum = arr.iter().fold(self.curr_sum, |a, b| a + b.unwrap_or(0));
974+
975+
Ok(())
976+
}
977+
978+
fn evaluate(&mut self) -> Result<ScalarValue> {
979+
let v = match self.double_output {
980+
true => self.curr_sum * 2,
981+
false => self.curr_sum,
982+
};
983+
984+
Ok(ScalarValue::from(v))
985+
}
986+
987+
fn size(&self) -> usize {
988+
9
989+
}
990+
991+
fn state(&mut self) -> Result<Vec<ScalarValue>> {
992+
Ok(vec![ScalarValue::from(self.curr_sum)])
993+
}
994+
995+
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
996+
self.update_batch(states)
997+
}
998+
}
999+
1000+
#[tokio::test]
1001+
async fn test_metadata_based_aggregate() -> Result<()> {
1002+
let data_array = Arc::new(UInt64Array::from(vec![0, 5, 10, 15, 20])) as ArrayRef;
1003+
let schema = Arc::new(Schema::new(vec![
1004+
Field::new("no_metadata", DataType::UInt64, true),
1005+
Field::new("with_metadata", DataType::UInt64, true).with_metadata(
1006+
[("modify_values".to_string(), "double_output".to_string())]
1007+
.into_iter()
1008+
.collect(),
1009+
),
1010+
]));
1011+
1012+
let batch = RecordBatch::try_new(
1013+
schema,
1014+
vec![Arc::clone(&data_array), Arc::clone(&data_array)],
1015+
)?;
1016+
1017+
let ctx = SessionContext::new();
1018+
ctx.register_batch("t", batch)?;
1019+
let df = ctx.table("t").await?;
1020+
1021+
let no_output_meta_udf =
1022+
AggregateUDF::from(MetadataBasedAggregateUdf::new(HashMap::new()));
1023+
let with_output_meta_udf = AggregateUDF::from(MetadataBasedAggregateUdf::new(
1024+
[("output_metatype".to_string(), "custom_value".to_string())]
1025+
.into_iter()
1026+
.collect(),
1027+
));
1028+
1029+
let df = df.aggregate(
1030+
vec![],
1031+
vec![
1032+
no_output_meta_udf
1033+
.call(vec![col("no_metadata")])
1034+
.alias("meta_no_in_no_out"),
1035+
no_output_meta_udf
1036+
.call(vec![col("with_metadata")])
1037+
.alias("meta_with_in_no_out"),
1038+
with_output_meta_udf
1039+
.call(vec![col("no_metadata")])
1040+
.alias("meta_no_in_with_out"),
1041+
with_output_meta_udf
1042+
.call(vec![col("with_metadata")])
1043+
.alias("meta_with_in_with_out"),
1044+
],
1045+
)?;
1046+
1047+
let actual = df.collect().await?;
1048+
1049+
// To test for output metadata handling, we set the expected values on the result
1050+
// To test for input metadata handling, we check the numbers returned
1051+
let mut output_meta = HashMap::new();
1052+
let _ = output_meta.insert("output_metatype".to_string(), "custom_value".to_string());
1053+
let expected_schema = Schema::new(vec![
1054+
Field::new("meta_no_in_no_out", DataType::UInt64, true),
1055+
Field::new("meta_with_in_no_out", DataType::UInt64, true),
1056+
Field::new("meta_no_in_with_out", DataType::UInt64, true)
1057+
.with_metadata(output_meta.clone()),
1058+
Field::new("meta_with_in_with_out", DataType::UInt64, true)
1059+
.with_metadata(output_meta.clone()),
1060+
]);
1061+
1062+
let expected = record_batch!(
1063+
("meta_no_in_no_out", UInt64, [50]),
1064+
("meta_with_in_no_out", UInt64, [100]),
1065+
("meta_no_in_with_out", UInt64, [50]),
1066+
("meta_with_in_with_out", UInt64, [100])
1067+
)?
1068+
.with_schema(Arc::new(expected_schema))?;
1069+
1070+
assert_eq!(expected, actual[0]);
1071+
1072+
Ok(())
1073+
}
1074+
1075+
#[tokio::test]
1076+
async fn test_metadata_based_aggregate_as_window() -> Result<()> {
1077+
let data_array = Arc::new(UInt64Array::from(vec![0, 5, 10, 15, 20])) as ArrayRef;
1078+
let schema = Arc::new(Schema::new(vec![
1079+
Field::new("no_metadata", DataType::UInt64, true),
1080+
Field::new("with_metadata", DataType::UInt64, true).with_metadata(
1081+
[("modify_values".to_string(), "double_output".to_string())]
1082+
.into_iter()
1083+
.collect(),
1084+
),
1085+
]));
1086+
1087+
let batch = RecordBatch::try_new(
1088+
schema,
1089+
vec![Arc::clone(&data_array), Arc::clone(&data_array)],
1090+
)?;
1091+
1092+
let ctx = SessionContext::new();
1093+
ctx.register_batch("t", batch)?;
1094+
let df = ctx.table("t").await?;
1095+
1096+
let no_output_meta_udf = Arc::new(AggregateUDF::from(
1097+
MetadataBasedAggregateUdf::new(HashMap::new()),
1098+
));
1099+
let with_output_meta_udf =
1100+
Arc::new(AggregateUDF::from(MetadataBasedAggregateUdf::new(
1101+
[("output_metatype".to_string(), "custom_value".to_string())]
1102+
.into_iter()
1103+
.collect(),
1104+
)));
1105+
1106+
let df = df.select(vec![
1107+
Expr::WindowFunction(WindowFunction::new(
1108+
WindowFunctionDefinition::AggregateUDF(Arc::clone(&no_output_meta_udf)),
1109+
vec![col("no_metadata")],
1110+
))
1111+
.alias("meta_no_in_no_out"),
1112+
Expr::WindowFunction(WindowFunction::new(
1113+
WindowFunctionDefinition::AggregateUDF(no_output_meta_udf),
1114+
vec![col("with_metadata")],
1115+
))
1116+
.alias("meta_with_in_no_out"),
1117+
Expr::WindowFunction(WindowFunction::new(
1118+
WindowFunctionDefinition::AggregateUDF(Arc::clone(&with_output_meta_udf)),
1119+
vec![col("no_metadata")],
1120+
))
1121+
.alias("meta_no_in_with_out"),
1122+
Expr::WindowFunction(WindowFunction::new(
1123+
WindowFunctionDefinition::AggregateUDF(with_output_meta_udf),
1124+
vec![col("with_metadata")],
1125+
))
1126+
.alias("meta_with_in_with_out"),
1127+
])?;
1128+
1129+
let actual = df.collect().await?;
1130+
1131+
// To test for output metadata handling, we set the expected values on the result
1132+
// To test for input metadata handling, we check the numbers returned
1133+
let mut output_meta = HashMap::new();
1134+
let _ = output_meta.insert("output_metatype".to_string(), "custom_value".to_string());
1135+
let expected_schema = Schema::new(vec![
1136+
Field::new("meta_no_in_no_out", DataType::UInt64, true),
1137+
Field::new("meta_with_in_no_out", DataType::UInt64, true),
1138+
Field::new("meta_no_in_with_out", DataType::UInt64, true)
1139+
.with_metadata(output_meta.clone()),
1140+
Field::new("meta_with_in_with_out", DataType::UInt64, true)
1141+
.with_metadata(output_meta.clone()),
1142+
]);
1143+
1144+
let expected = record_batch!(
1145+
("meta_no_in_no_out", UInt64, [50, 50, 50, 50, 50]),
1146+
("meta_with_in_no_out", UInt64, [100, 100, 100, 100, 100]),
1147+
("meta_no_in_with_out", UInt64, [50, 50, 50, 50, 50]),
1148+
("meta_with_in_with_out", UInt64, [100, 100, 100, 100, 100])
1149+
)?
1150+
.with_schema(Arc::new(expected_schema))?;
1151+
1152+
assert_eq!(expected, actual[0]);
1153+
1154+
Ok(())
1155+
}

0 commit comments

Comments (0)