Skip to content

Commit f799fca

Browse files
committed
Move window functions over to use Field instead of DataType
1 parent 4bd5a20 commit f799fca

File tree

11 files changed

+285
-100
lines changed

11 files changed

+285
-100
lines changed

datafusion/core/tests/user_defined/user_defined_window_functions.rs

Lines changed: 154 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,15 @@
1818
//! This module contains end to end tests of creating
1919
//! user defined window functions
2020
21-
use arrow::array::{ArrayRef, AsArray, Int64Array, RecordBatch, StringArray};
21+
use arrow::array::{
22+
record_batch, Array, ArrayRef, AsArray, Int64Array, RecordBatch, StringArray,
23+
UInt64Array,
24+
};
2225
use arrow::datatypes::{DataType, Field, Schema};
2326
use datafusion::common::test_util::batches_to_string;
2427
use datafusion::common::{Result, ScalarValue};
2528
use datafusion::prelude::SessionContext;
29+
use datafusion_common::exec_datafusion_err;
2630
use datafusion_expr::{
2731
PartitionEvaluator, Signature, TypeSignature, Volatility, WindowUDF, WindowUDFImpl,
2832
};
@@ -34,6 +38,7 @@ use datafusion_physical_expr::{
3438
expressions::{col, lit},
3539
PhysicalExpr,
3640
};
41+
use std::collections::HashMap;
3742
use std::{
3843
any::Any,
3944
ops::Range,
@@ -723,11 +728,11 @@ fn test_default_expressions() -> Result<()> {
723728
];
724729

725730
for input_exprs in &test_cases {
726-
let input_types = input_exprs
731+
let input_fields = input_exprs
727732
.iter()
728-
.map(|expr: &Arc<dyn PhysicalExpr>| expr.data_type(&schema).unwrap())
733+
.map(|expr: &Arc<dyn PhysicalExpr>| expr.return_field(&schema).unwrap())
729734
.collect::<Vec<_>>();
730-
let expr_args = ExpressionArgs::new(input_exprs, &input_types);
735+
let expr_args = ExpressionArgs::new(input_exprs, &input_fields);
731736

732737
let ret_exprs = udwf.expressions(expr_args);
733738

@@ -753,3 +758,148 @@ fn test_default_expressions() -> Result<()> {
753758
}
754759
Ok(())
755760
}
761+
762+
#[derive(Debug)]
763+
struct MetadataBasedWindowUdf {
764+
name: String,
765+
signature: Signature,
766+
metadata: HashMap<String, String>,
767+
}
768+
769+
impl MetadataBasedWindowUdf {
770+
fn new(metadata: HashMap<String, String>) -> Self {
771+
// The name we return must be unique. Otherwise we will not call distinct
772+
// instances of this UDF. This is a small hack for the unit tests to get unique
773+
// names, but you could do something more elegant with the metadata.
774+
let name = format!("metadata_based_udf_{}", metadata.len());
775+
Self {
776+
name,
777+
signature: Signature::exact(vec![DataType::UInt64], Volatility::Immutable),
778+
metadata,
779+
}
780+
}
781+
}
782+
783+
impl WindowUDFImpl for MetadataBasedWindowUdf {
784+
fn as_any(&self) -> &dyn Any {
785+
self
786+
}
787+
788+
fn name(&self) -> &str {
789+
&self.name
790+
}
791+
792+
fn signature(&self) -> &Signature {
793+
&self.signature
794+
}
795+
796+
fn partition_evaluator(
797+
&self,
798+
partition_evaluator_args: PartitionEvaluatorArgs,
799+
) -> Result<Box<dyn PartitionEvaluator>> {
800+
let input_field = partition_evaluator_args
801+
.input_fields()
802+
.first()
803+
.ok_or(exec_datafusion_err!("Expected one argument"))?;
804+
805+
let double_output = input_field
806+
.metadata()
807+
.get("modify_values")
808+
.map(|v| v == "double_output")
809+
.unwrap_or(false);
810+
811+
Ok(Box::new(MetadataBasedPartitionEvaluator { double_output }))
812+
}
813+
814+
fn field(&self, field_args: WindowUDFFieldArgs) -> Result<Field> {
815+
Ok(Field::new(field_args.name(), DataType::UInt64, true)
816+
.with_metadata(self.metadata.clone()))
817+
}
818+
}
819+
820+
#[derive(Debug)]
821+
struct MetadataBasedPartitionEvaluator {
822+
double_output: bool,
823+
}
824+
825+
impl PartitionEvaluator for MetadataBasedPartitionEvaluator {
826+
fn evaluate_all(&mut self, values: &[ArrayRef], num_rows: usize) -> Result<ArrayRef> {
827+
let values = values[0].as_any().downcast_ref::<UInt64Array>().unwrap();
828+
let sum = values.iter().fold(0_u64, |acc, v| acc + v.unwrap_or(0));
829+
830+
let result = if self.double_output { sum * 2 } else { sum };
831+
832+
Ok(Arc::new(UInt64Array::from_value(result, num_rows)))
833+
}
834+
}
835+
836+
#[tokio::test]
837+
async fn test_metadata_based_window_fn() -> Result<()> {
838+
let data_array = Arc::new(UInt64Array::from(vec![0, 5, 10, 15, 20])) as ArrayRef;
839+
let schema = Arc::new(Schema::new(vec![
840+
Field::new("no_metadata", DataType::UInt64, true),
841+
Field::new("with_metadata", DataType::UInt64, true).with_metadata(
842+
[("modify_values".to_string(), "double_output".to_string())]
843+
.into_iter()
844+
.collect(),
845+
),
846+
]));
847+
848+
let batch = RecordBatch::try_new(
849+
schema,
850+
vec![Arc::clone(&data_array), Arc::clone(&data_array)],
851+
)?;
852+
853+
let ctx = SessionContext::new();
854+
ctx.register_batch("t", batch)?;
855+
let df = ctx.table("t").await?;
856+
857+
let no_output_meta_udf = WindowUDF::from(MetadataBasedWindowUdf::new(HashMap::new()));
858+
let with_output_meta_udf = WindowUDF::from(MetadataBasedWindowUdf::new(
859+
[("output_metatype".to_string(), "custom_value".to_string())]
860+
.into_iter()
861+
.collect(),
862+
));
863+
864+
let df = df.select(vec![
865+
no_output_meta_udf
866+
.call(vec![datafusion_expr::col("no_metadata")])
867+
.alias("meta_no_in_no_out"),
868+
no_output_meta_udf
869+
.call(vec![datafusion_expr::col("with_metadata")])
870+
.alias("meta_with_in_no_out"),
871+
with_output_meta_udf
872+
.call(vec![datafusion_expr::col("no_metadata")])
873+
.alias("meta_no_in_with_out"),
874+
with_output_meta_udf
875+
.call(vec![datafusion_expr::col("with_metadata")])
876+
.alias("meta_with_in_with_out"),
877+
])?;
878+
879+
let actual = df.collect().await?;
880+
881+
// To test for output metadata handling, we set the expected values on the result
882+
// To test for input metadata handling, we check the numbers returned
883+
let mut output_meta = HashMap::new();
884+
let _ = output_meta.insert("output_metatype".to_string(), "custom_value".to_string());
885+
let expected_schema = Schema::new(vec![
886+
Field::new("meta_no_in_no_out", DataType::UInt64, true),
887+
Field::new("meta_with_in_no_out", DataType::UInt64, true),
888+
Field::new("meta_no_in_with_out", DataType::UInt64, true)
889+
.with_metadata(output_meta.clone()),
890+
Field::new("meta_with_in_with_out", DataType::UInt64, true)
891+
.with_metadata(output_meta.clone()),
892+
]);
893+
894+
let expected = record_batch!(
895+
("meta_no_in_no_out", UInt64, [50, 50, 50, 50, 50]),
896+
("meta_with_in_no_out", UInt64, [100, 100, 100, 100, 100]),
897+
("meta_no_in_with_out", UInt64, [50, 50, 50, 50, 50]),
898+
("meta_with_in_with_out", UInt64, [100, 100, 100, 100, 100])
899+
)?
900+
.with_schema(Arc::new(expected_schema))?;
901+
902+
assert_eq!(expected, actual[0]);
903+
904+
Ok(())
905+
}

datafusion/expr/src/expr.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ use crate::logical_plan::Subquery;
2828
use crate::Volatility;
2929
use crate::{udaf, ExprSchemable, Operator, Signature, WindowFrame, WindowUDF};
3030

31-
use arrow::datatypes::{DataType, FieldRef};
31+
use arrow::datatypes::{DataType, Field, FieldRef};
3232
use datafusion_common::cse::{HashNode, NormalizeEq, Normalizeable};
3333
use datafusion_common::tree_node::{
3434
Transformed, TransformedResult, TreeNode, TreeNodeContainer, TreeNodeRecursion,
@@ -838,19 +838,19 @@ pub enum WindowFunctionDefinition {
838838

839839
impl WindowFunctionDefinition {
840840
/// Returns the output [`Field`] of the window function
841-
pub fn return_type(
841+
pub fn return_field(
842842
&self,
843-
input_expr_types: &[DataType],
843+
input_expr_fields: &[Field],
844844
_input_expr_nullable: &[bool],
845845
display_name: &str,
846-
) -> Result<DataType> {
846+
) -> Result<Field> {
847847
match self {
848848
WindowFunctionDefinition::AggregateUDF(fun) => {
849-
fun.return_type(input_expr_types)
849+
fun.return_field(input_expr_fields)
850+
}
851+
WindowFunctionDefinition::WindowUDF(fun) => {
852+
fun.field(WindowUDFFieldArgs::new(input_expr_fields, display_name))
850853
}
851-
WindowFunctionDefinition::WindowUDF(fun) => fun
852-
.field(WindowUDFFieldArgs::new(input_expr_types, display_name))
853-
.map(|field| field.data_type().clone()),
854854
}
855855
}
856856

datafusion/expr/src/expr_schema.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -647,8 +647,8 @@ impl Expr {
647647
.map(|f| f.data_type())
648648
.cloned()
649649
.collect::<Vec<_>>();
650-
let new_types =
651-
data_types_with_window_udf(&data_types, udwf).map_err(|err| {
650+
let new_fields =
651+
data_types_with_window_udf(&fields, udwf).map_err(|err| {
652652
plan_datafusion_err!(
653653
"{} {}",
654654
match err {
@@ -663,7 +663,7 @@ impl Expr {
663663
)
664664
})?;
665665
let (_, function_name) = self.qualified_name();
666-
let field_args = WindowUDFFieldArgs::new(&new_types, &function_name);
666+
let field_args = WindowUDFFieldArgs::new(&new_fields, &function_name);
667667

668668
udwf.field(field_args)
669669
.map(|field| (field.data_type().clone(), field.is_nullable()))

datafusion/expr/src/type_coercion/functions.rs

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -130,13 +130,13 @@ pub fn data_types_with_aggregate_udf(
130130
/// For more details on coercion in general, please see the
131131
/// [`type_coercion`](crate::type_coercion) module.
132132
pub fn data_types_with_window_udf(
133-
current_types: &[DataType],
133+
current_fields: &[Field],
134134
func: &WindowUDF,
135-
) -> Result<Vec<DataType>> {
135+
) -> Result<Vec<Field>> {
136136
let signature = func.signature();
137137
let type_signature = &signature.type_signature;
138138

139-
if current_types.is_empty() && type_signature != &TypeSignature::UserDefined {
139+
if current_fields.is_empty() && type_signature != &TypeSignature::UserDefined {
140140
if type_signature.supports_zero_argument() {
141141
return Ok(vec![]);
142142
} else if type_signature.used_to_support_zero_arguments() {
@@ -147,16 +147,28 @@ pub fn data_types_with_window_udf(
147147
}
148148
}
149149

150+
let current_types = current_fields
151+
.iter()
152+
.map(|f| f.data_type())
153+
.cloned()
154+
.collect::<Vec<_>>();
150155
let valid_types =
151-
get_valid_types_with_window_udf(type_signature, current_types, func)?;
156+
get_valid_types_with_window_udf(type_signature, &current_types, func)?;
152157
if valid_types
153158
.iter()
154-
.any(|data_type| data_type == current_types)
159+
.any(|data_type| data_type == &current_types)
155160
{
156-
return Ok(current_types.to_vec());
161+
return Ok(current_fields.to_vec());
157162
}
158163

159-
try_coerce_types(func.name(), valid_types, current_types, type_signature)
164+
let updated_types =
165+
try_coerce_types(func.name(), valid_types, &current_types, type_signature)?;
166+
167+
Ok(current_fields
168+
.iter()
169+
.zip(updated_types)
170+
.map(|(current_field, new_type)| current_field.clone().with_data_type(new_type))
171+
.collect())
160172
}
161173

162174
/// Performs type coercion for function arguments.

datafusion/functions-window-common/src/expr.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use datafusion_common::arrow::datatypes::DataType;
18+
use datafusion_common::arrow::datatypes::Field;
1919
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
2020
use std::sync::Arc;
2121

@@ -27,7 +27,7 @@ pub struct ExpressionArgs<'a> {
2727
input_exprs: &'a [Arc<dyn PhysicalExpr>],
2828
/// The corresponding fields of expressions passed as arguments
2929
/// to the user-defined window function.
30-
input_types: &'a [DataType],
30+
input_fields: &'a [Field],
3131
}
3232

3333
impl<'a> ExpressionArgs<'a> {
@@ -42,11 +42,11 @@ impl<'a> ExpressionArgs<'a> {
4242
///
4343
pub fn new(
4444
input_exprs: &'a [Arc<dyn PhysicalExpr>],
45-
input_types: &'a [DataType],
45+
input_fields: &'a [Field],
4646
) -> Self {
4747
Self {
4848
input_exprs,
49-
input_types,
49+
input_fields,
5050
}
5151
}
5252

@@ -56,9 +56,9 @@ impl<'a> ExpressionArgs<'a> {
5656
self.input_exprs
5757
}
5858

59-
/// Returns the [`DataType`]s corresponding to the input expressions
59+
/// Returns the [`Field`]s corresponding to the input expressions
6060
/// to the user-defined window function.
61-
pub fn input_types(&self) -> &'a [DataType] {
62-
self.input_types
61+
pub fn input_fields(&self) -> &'a [Field] {
62+
self.input_fields
6363
}
6464
}

datafusion/functions-window-common/src/field.rs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use datafusion_common::arrow::datatypes::DataType;
18+
use datafusion_common::arrow::datatypes::Field;
1919

2020
/// Metadata for defining the result field from evaluating a
2121
/// user-defined window function.
2222
pub struct WindowUDFFieldArgs<'a> {
2323
/// The fields corresponding to the arguments to the
2424
/// user-defined window function.
25-
input_types: &'a [DataType],
25+
input_fields: &'a [Field],
2626
/// The display name of the user-defined window function.
2727
display_name: &'a str,
2828
}
@@ -32,22 +32,22 @@ impl<'a> WindowUDFFieldArgs<'a> {
3232
///
3333
/// # Arguments
3434
///
35-
/// * `input_types` - The data types corresponding to the
35+
/// * `input_fields` - The fields corresponding to the
3636
/// arguments to the user-defined window function.
3737
/// * `display_name` - The qualified schema name of the
3838
/// user-defined window function expression.
3939
///
40-
pub fn new(input_types: &'a [DataType], display_name: &'a str) -> Self {
40+
pub fn new(input_fields: &'a [Field], display_name: &'a str) -> Self {
4141
WindowUDFFieldArgs {
42-
input_types,
42+
input_fields,
4343
display_name,
4444
}
4545
}
4646

47-
/// Returns the data type of input expressions passed as arguments
47+
/// Returns the fields of the input expressions passed as arguments
4848
/// to the user-defined window function.
49-
pub fn input_types(&self) -> &[DataType] {
50-
self.input_types
49+
pub fn input_fields(&self) -> &[Field] {
50+
self.input_fields
5151
}
5252

5353
/// Returns the name for the field of the final result of evaluating
@@ -56,9 +56,9 @@ impl<'a> WindowUDFFieldArgs<'a> {
5656
self.display_name
5757
}
5858

59-
/// Returns `Some(DataType)` of input expression at index, otherwise
59+
/// Returns `Some(Field)` of input expression at index, otherwise
6060
/// returns `None` if the index is out of bounds.
61-
pub fn get_input_type(&self, index: usize) -> Option<DataType> {
62-
self.input_types.get(index).cloned()
61+
pub fn get_input_field(&self, index: usize) -> Option<Field> {
62+
self.input_fields.get(index).cloned()
6363
}
6464
}

0 commit comments

Comments
 (0)