Skip to content

Commit 54de5ab

Browse files
committed
More work on moving to Field from DataType for aggregates
1 parent 8a848fd commit 54de5ab

File tree

20 files changed

+255
-100
lines changed

20 files changed

+255
-100
lines changed

datafusion-examples/examples/advanced_udaf.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ impl AggregateUDFImpl for GeoMeanUdaf {
9494
/// This is the description of the state. The accumulator's `state()` must match the types here.
9595
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
9696
Ok(vec![
97-
Field::new("prod", args.return_type.clone(), true),
97+
Field::new("prod", args.return_field.data_type().clone(), true),
9898
Field::new("n", DataType::UInt32, true),
9999
])
100100
}

datafusion/core/tests/user_defined/user_defined_aggregates.rs

Lines changed: 137 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
//! This module contains end to end demonstrations of creating
1919
//! user defined aggregate functions
2020
21+
use std::any::Any;
22+
use std::collections::HashMap;
2123
use std::hash::{DefaultHasher, Hash, Hasher};
2224
use std::mem::{size_of, size_of_val};
2325
use std::sync::{
@@ -26,10 +28,10 @@ use std::sync::{
2628
};
2729

2830
use arrow::array::{
29-
types::UInt64Type, AsArray, Int32Array, PrimitiveArray, StringArray, StructArray,
31+
types::UInt64Type, Array, AsArray, Int32Array, PrimitiveArray, StringArray,
32+
StructArray, UInt64Array,
3033
};
3134
use arrow::datatypes::{Fields, Schema};
32-
3335
use datafusion::common::test_util::batches_to_string;
3436
use datafusion::dataframe::DataFrame;
3537
use datafusion::datasource::MemTable;
@@ -48,7 +50,7 @@ use datafusion::{
4850
prelude::SessionContext,
4951
scalar::ScalarValue,
5052
};
51-
use datafusion_common::assert_contains;
53+
use datafusion_common::{assert_contains, exec_datafusion_err};
5254
use datafusion_common::{cast::as_primitive_array, exec_err};
5355
use datafusion_expr::{
5456
col, create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator,
@@ -781,7 +783,7 @@ struct TestGroupsAccumulator {
781783
}
782784

783785
impl AggregateUDFImpl for TestGroupsAccumulator {
784-
fn as_any(&self) -> &dyn std::any::Any {
786+
fn as_any(&self) -> &dyn Any {
785787
self
786788
}
787789

@@ -890,3 +892,134 @@ impl GroupsAccumulator for TestGroupsAccumulator {
890892
size_of::<u64>()
891893
}
892894
}
895+
896+
/// Test aggregate UDF whose output field carries caller-supplied metadata.
#[derive(Debug)]
897+
struct MetadataBasedAggregateUdf {
898+
/// Function name returned by `name()`; built in `new` from the metadata size.
name: String,
899+
/// Signature returned by `signature()`.
signature: Signature,
900+
/// Metadata attached to the field produced by `return_field`.
metadata: HashMap<String, String>,
901+
}
902+
903+
impl MetadataBasedAggregateUdf {
904+
/// Creates a UDF with an exact signature of a single `UInt64` argument and
/// stores `metadata` so that `return_field` can attach it to the output field.
///
/// The function name is `metadata_based_udf_<N>`, where `N` is the number of
/// entries in `metadata`.
fn new(metadata: HashMap<String, String>) -> Self {
905+
// The name we return must be unique. Otherwise we will not call distinct
906+
// instances of this UDF. This is a small hack for the unit tests to get unique
907+
// names, but you could do something more elegant with the metadata.
// NOTE(review): the name depends only on `metadata.len()`, so two instances
// whose maps have the same number of entries end up with the same name.
908+
let name = format!("metadata_based_udf_{}", metadata.len());
909+
Self {
910+
name,
911+
signature: Signature::exact(vec![DataType::UInt64], Volatility::Immutable),
912+
metadata,
913+
}
914+
}
915+
}
916+
917+
/// `AggregateUDFImpl` for the metadata-driven test UDF: the output field is
/// described by `return_field`, and the accumulator's behavior is chosen from
/// the metadata of the first input expression's field.
impl AggregateUDFImpl for MetadataBasedAggregateUdf {
918+
fn as_any(&self) -> &dyn Any {
919+
self
920+
}
921+
922+
fn name(&self) -> &str {
923+
&self.name
924+
}
925+
926+
fn signature(&self) -> &Signature {
927+
&self.signature
928+
}
929+
930+
/// Intentionally unimplemented: this UDF describes its output through
/// `return_field`, so this method panics if it is ever invoked.
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
931+
unimplemented!("this should never be called since return_field is implemented");
932+
}
933+
934+
/// Returns a nullable `UInt64` field named after this UDF, carrying a clone
/// of the metadata supplied at construction.
fn return_field(&self, _arg_fields: &[Field]) -> Result<Field> {
935+
Ok(Field::new(self.name(), DataType::UInt64, true)
936+
.with_metadata(self.metadata.clone()))
937+
}
938+
939+
/// Builds a `MetadataBasedAccumulator` starting at a sum of 0.
///
/// Doubling is enabled only when the first input expression's field has the
/// metadata entry `modify_values = double_output`.
///
/// # Errors
/// Returns an execution error when `acc_args.exprs` is empty, and propagates
/// any error from resolving the input expression's field.
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
940+
let input_expr = acc_args
941+
.exprs
942+
.first()
943+
// Build the error lazily so nothing is constructed on the success path.
.ok_or_else(|| exec_datafusion_err!("Expected one argument"))?;
944+
let input_field = input_expr.return_field(acc_args.schema)?;
945+
946+
let double_output = input_field
947+
.metadata()
948+
.get("modify_values")
949+
.map(|v| v == "double_output")
950+
.unwrap_or(false);
951+
952+
Ok(Box::new(MetadataBasedAccumulator {
953+
double_output,
954+
curr_sum: 0,
955+
}))
956+
}
957+
}
958+
959+
/// Accumulator that sums `UInt64` inputs and can double the final result.
#[derive(Debug)]
960+
struct MetadataBasedAccumulator {
961+
/// When true, `evaluate` returns twice the running sum.
double_output: bool,
962+
/// Running sum of the values seen so far; null inputs are added as 0.
curr_sum: u64,
963+
}
964+
965+
/// Sums the first input column as `u64`; the final value is optionally doubled.
impl Accumulator for MetadataBasedAccumulator {
966+
/// Adds every value of `values[0]` to the running sum, treating nulls as 0.
///
/// # Errors
/// Returns an execution error when `values[0]` is not a `UInt64Array`.
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
967+
let arr = values[0]
968+
.as_any()
969+
.downcast_ref::<UInt64Array>()
970+
// Build the error lazily so nothing is constructed on the success path.
.ok_or_else(|| exec_datafusion_err!("Expected UInt64Array"))?;
971+
972+
self.curr_sum = arr.iter().fold(self.curr_sum, |a, b| a + b.unwrap_or(0));
973+
974+
Ok(())
975+
}
976+
977+
/// Returns the running sum, doubled when `double_output` is set.
fn evaluate(&mut self) -> Result<ScalarValue> {
978+
let v = match self.double_output {
979+
true => self.curr_sum * 2,
980+
false => self.curr_sum,
981+
};
982+
983+
Ok(ScalarValue::from(v))
984+
}
985+
986+
/// Size of this accumulator in bytes. It holds no heap data, so the size of
/// the struct itself is reported instead of a hard-coded constant.
fn size(&self) -> usize {
987+
size_of_val(self)
988+
}
989+
990+
/// The intermediate state is the single running sum.
fn state(&mut self) -> Result<Vec<ScalarValue>> {
991+
Ok(vec![ScalarValue::from(self.curr_sum)])
992+
}
993+
994+
/// Partial states are plain sums, so merging reuses `update_batch`.
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
995+
self.update_batch(states)
996+
}
997+
}
998+
999+
/// Registers a one-row `UInt64` table `t` and two `MetadataBasedAggregateUdf`
/// instances (`metadata_based_udf_0` with empty metadata, `metadata_based_udf_1`
/// with one entry), then runs each over column `a` and shows the result.
///
/// NOTE(review): the results are only shown, never asserted, so this test
/// passes as long as both queries execute without error.
/// NOTE(review): `modify_values = double_output` is given to the UDF here, but
/// `accumulator` looks that key up on the *input* field's metadata; it looks
/// like the doubling path is not exercised by this test — confirm.
#[tokio::test]
1000+
async fn test_metadata_based_accumulator() -> Result<()> {
1001+
let ctx = SessionContext::new();
1002+
let arr = UInt64Array::from(vec![1]);
1003+
let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(arr) as _)])?;
1004+
ctx.register_batch("t", batch).unwrap();
1005+
1006+
// Zero metadata entries -> registered under the name `metadata_based_udf_0`.
let udaf_no_metadata =
1007+
AggregateUDF::from(MetadataBasedAggregateUdf::new(HashMap::new()));
1008+
// One metadata entry -> registered under the name `metadata_based_udf_1`.
let udaf_with_metadata = AggregateUDF::from(MetadataBasedAggregateUdf::new(
1009+
HashMap::from_iter([("modify_values".to_string(), "double_output".to_string())]),
1010+
));
1011+
ctx.register_udaf(udaf_no_metadata);
1012+
ctx.register_udaf(udaf_with_metadata);
1013+
1014+
let sql_df = ctx
1015+
.sql("SELECT metadata_based_udf_0(a) FROM t group by a")
1016+
.await?;
1017+
sql_df.show().await?;
1018+
1019+
let sql_df = ctx
1020+
.sql("SELECT metadata_based_udf_1(a) FROM t group by a")
1021+
.await?;
1022+
sql_df.show().await?;
1023+
1024+
Ok(())
1025+
}

datafusion/expr-common/src/type_coercion/aggregates.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
use crate::signature::TypeSignature;
1919
use arrow::datatypes::{
20-
DataType, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
20+
DataType, Field, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
2121
DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE,
2222
};
2323

@@ -89,7 +89,7 @@ pub static TIMES: &[DataType] = &[
8989
/// number of input types.
9090
pub fn check_arg_count(
9191
func_name: &str,
92-
input_types: &[DataType],
92+
input_types: &[Field],
9393
signature: &TypeSignature,
9494
) -> Result<()> {
9595
match signature {

datafusion/expr/src/udaf.rs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ where
410410
/// fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { unimplemented!() }
411411
/// fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
412412
/// Ok(vec![
413-
/// Field::new("value", args.return_type.clone(), true),
413+
/// args.return_field.clone().with_name("value"),
414414
/// Field::new("ordering", DataType::UInt32, true)
415415
/// ])
416416
/// }
@@ -745,11 +745,10 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
745745
/// be derived from `name`. See [`format_state_name`] for a utility function
746746
/// to generate a unique name.
747747
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
748-
let fields = vec![Field::new(
749-
format_state_name(args.name, "value"),
750-
args.return_type.clone(),
751-
true,
752-
)];
748+
let fields = vec![args
749+
.return_field
750+
.clone()
751+
.with_name(format_state_name(args.name, "value"))];
753752

754753
Ok(fields
755754
.into_iter()

datafusion/functions-aggregate-common/src/accumulator.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use arrow::datatypes::{DataType, Field, Schema};
18+
use arrow::datatypes::{Field, Schema};
1919
use datafusion_common::Result;
2020
use datafusion_expr_common::accumulator::Accumulator;
2121
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
@@ -28,7 +28,7 @@ use std::sync::Arc;
2828
#[derive(Debug)]
2929
pub struct AccumulatorArgs<'a> {
3030
/// The return type of the aggregate function.
31-
pub return_type: &'a DataType,
31+
pub return_field: &'a Field,
3232

3333
/// The schema of the input arguments
3434
pub schema: &'a Schema,
@@ -81,11 +81,11 @@ pub struct StateFieldsArgs<'a> {
8181
/// The name of the aggregate function.
8282
pub name: &'a str,
8383

84-
/// The input types of the aggregate function.
85-
pub input_types: &'a [DataType],
84+
/// The input fields of the aggregate function.
85+
pub input_fields: &'a [Field],
8686

87-
/// The return type of the aggregate function.
88-
pub return_type: &'a DataType,
87+
/// The return field of the aggregate function.
88+
pub return_field: &'a Field,
8989

9090
/// The ordering fields of the aggregate function.
9191
pub ordering_fields: &'a [Field],

datafusion/functions-aggregate/benches/count.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ use std::sync::Arc;
2828
fn prepare_accumulator() -> Box<dyn GroupsAccumulator> {
2929
let schema = Arc::new(Schema::new(vec![Field::new("f", DataType::Int32, true)]));
3030
let accumulator_args = AccumulatorArgs {
31-
return_type: &DataType::Int64,
31+
return_field: &Field::new("f", DataType::Int64, true),
3232
schema: &schema,
3333
ignore_nulls: false,
3434
ordering_req: &LexOrdering::default(),

datafusion/functions-aggregate/benches/sum.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,10 @@ use datafusion_physical_expr_common::sort_expr::LexOrdering;
2626
use std::sync::Arc;
2727

2828
fn prepare_accumulator(data_type: &DataType) -> Box<dyn GroupsAccumulator> {
29-
let schema = Arc::new(Schema::new(vec![Field::new("f", data_type.clone(), true)]));
29+
let field = Field::new("f", data_type.clone(), true);
30+
let schema = Arc::new(Schema::new(vec![field.clone()]));
3031
let accumulator_args = AccumulatorArgs {
31-
return_type: data_type,
32+
return_field: &field,
3233
schema: &schema,
3334
ignore_nulls: false,
3435
ordering_req: &LexOrdering::default(),

datafusion/functions-aggregate/src/array_agg.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,15 +114,15 @@ impl AggregateUDFImpl for ArrayAgg {
114114
return Ok(vec![Field::new_list(
115115
format_state_name(args.name, "distinct_array_agg"),
116116
// See COMMENTS.md to understand why nullable is set to true
117-
Field::new_list_field(args.input_types[0].clone(), true),
117+
Field::new_list_field(args.input_fields[0].data_type().clone(), true),
118118
true,
119119
)]);
120120
}
121121

122122
let mut fields = vec![Field::new_list(
123123
format_state_name(args.name, "array_agg"),
124124
// See COMMENTS.md to understand why nullable is set to true
125-
Field::new_list_field(args.input_types[0].clone(), true),
125+
Field::new_list_field(args.input_fields[0].data_type().clone(), true),
126126
true,
127127
)];
128128

@@ -984,7 +984,7 @@ mod tests {
984984
}
985985

986986
struct ArrayAggAccumulatorBuilder {
987-
data_type: DataType,
987+
return_field: Field,
988988
distinct: bool,
989989
ordering: LexOrdering,
990990
schema: Schema,
@@ -997,7 +997,7 @@ mod tests {
997997

998998
fn new(data_type: DataType) -> Self {
999999
Self {
1000-
data_type: data_type.clone(),
1000+
return_field: Field::new("f", data_type.clone(), true),
10011001
distinct: false,
10021002
ordering: Default::default(),
10031003
schema: Schema {
@@ -1029,7 +1029,7 @@ mod tests {
10291029

10301030
fn build(&self) -> Result<Box<dyn Accumulator>> {
10311031
ArrayAgg::default().accumulator(AccumulatorArgs {
1032-
return_type: &self.data_type,
1032+
return_field: &self.return_field,
10331033
schema: &self.schema,
10341034
ignore_nulls: false,
10351035
ordering_req: &self.ordering,

0 commit comments

Comments
 (0)