Skip to content

Commit a91be04

Browse files
authored
Window UDF signature check (apache#12045)
* udwf sig Signed-off-by: jayzhan211 <[email protected]> * add coerce_types Signed-off-by: jayzhan211 <[email protected]> * add doc Signed-off-by: jayzhan211 <[email protected]> --------- Signed-off-by: jayzhan211 <[email protected]>
1 parent 950dc73 commit a91be04

File tree

4 files changed

+130
-9
lines changed

4 files changed

+130
-9
lines changed

datafusion/expr/src/expr_schema.rs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use crate::expr::{
2222
};
2323
use crate::type_coercion::binary::get_result_type;
2424
use crate::type_coercion::functions::{
25-
data_types_with_aggregate_udf, data_types_with_scalar_udf,
25+
data_types_with_aggregate_udf, data_types_with_scalar_udf, data_types_with_window_udf,
2626
};
2727
use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition};
2828
use arrow::compute::can_cast_types;
@@ -191,6 +191,21 @@ impl ExprSchemable for Expr {
191191
})?;
192192
Ok(fun.return_type(&new_types, &nullability)?)
193193
}
194+
WindowFunctionDefinition::WindowUDF(udwf) => {
195+
let new_types = data_types_with_window_udf(&data_types, udwf)
196+
.map_err(|err| {
197+
plan_datafusion_err!(
198+
"{} {}",
199+
err,
200+
utils::generate_signature_error_msg(
201+
fun.name(),
202+
fun.signature().clone(),
203+
&data_types
204+
)
205+
)
206+
})?;
207+
Ok(fun.return_type(&new_types, &nullability)?)
208+
}
194209
_ => fun.return_type(&data_types, &nullability),
195210
}
196211
}

datafusion/expr/src/type_coercion/functions.rs

Lines changed: 67 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,22 +15,21 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use std::sync::Arc;
19-
20-
use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature};
18+
use super::binary::{binary_numeric_coercion, comparison_coercion};
19+
use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF};
2120
use arrow::{
2221
compute::can_cast_types,
2322
datatypes::{DataType, TimeUnit},
2423
};
25-
use datafusion_common::utils::{coerced_fixed_size_list_to_list, list_ndims};
2624
use datafusion_common::{
27-
exec_err, internal_datafusion_err, internal_err, plan_err, Result,
25+
exec_err, internal_datafusion_err, internal_err, plan_err,
26+
utils::{coerced_fixed_size_list_to_list, list_ndims},
27+
Result,
2828
};
2929
use datafusion_expr_common::signature::{
3030
ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD,
3131
};
32-
33-
use super::binary::{binary_numeric_coercion, comparison_coercion};
32+
use std::sync::Arc;
3433

3534
/// Performs type coercion for scalar function arguments.
3635
///
@@ -66,6 +65,13 @@ pub fn data_types_with_scalar_udf(
6665
try_coerce_types(valid_types, current_types, &signature.type_signature)
6766
}
6867

68+
/// Performs type coercion for aggregate function arguments.
69+
///
70+
/// Returns the data types to which each argument must be coerced to
71+
/// match `signature`.
72+
///
73+
/// For more details on coercion in general, please see the
74+
/// [`type_coercion`](crate::type_coercion) module.
6975
pub fn data_types_with_aggregate_udf(
7076
current_types: &[DataType],
7177
func: &AggregateUDF,
@@ -95,6 +101,39 @@ pub fn data_types_with_aggregate_udf(
95101
try_coerce_types(valid_types, current_types, &signature.type_signature)
96102
}
97103

104+
/// Performs type coercion for window function arguments.
105+
///
106+
/// Returns the data types to which each argument must be coerced to
107+
/// match `signature`.
108+
///
109+
/// For more details on coercion in general, please see the
110+
/// [`type_coercion`](crate::type_coercion) module.
111+
pub fn data_types_with_window_udf(
112+
current_types: &[DataType],
113+
func: &WindowUDF,
114+
) -> Result<Vec<DataType>> {
115+
let signature = func.signature();
116+
117+
if current_types.is_empty() {
118+
if signature.type_signature.supports_zero_argument() {
119+
return Ok(vec![]);
120+
} else {
121+
return plan_err!("{} does not support zero arguments.", func.name());
122+
}
123+
}
124+
125+
let valid_types =
126+
get_valid_types_with_window_udf(&signature.type_signature, current_types, func)?;
127+
if valid_types
128+
.iter()
129+
.any(|data_type| data_type == current_types)
130+
{
131+
return Ok(current_types.to_vec());
132+
}
133+
134+
try_coerce_types(valid_types, current_types, &signature.type_signature)
135+
}
136+
98137
/// Performs type coercion for function arguments.
99138
///
100139
/// Returns the data types to which each argument must be coerced to
@@ -205,6 +244,27 @@ fn get_valid_types_with_aggregate_udf(
205244
Ok(valid_types)
206245
}
207246

247+
fn get_valid_types_with_window_udf(
248+
signature: &TypeSignature,
249+
current_types: &[DataType],
250+
func: &WindowUDF,
251+
) -> Result<Vec<Vec<DataType>>> {
252+
let valid_types = match signature {
253+
TypeSignature::UserDefined => match func.coerce_types(current_types) {
254+
Ok(coerced_types) => vec![coerced_types],
255+
Err(e) => return exec_err!("User-defined coercion failed with {:?}", e),
256+
},
257+
TypeSignature::OneOf(signatures) => signatures
258+
.iter()
259+
.filter_map(|t| get_valid_types_with_window_udf(t, current_types, func).ok())
260+
.flatten()
261+
.collect::<Vec<_>>(),
262+
_ => get_valid_types(signature, current_types)?,
263+
};
264+
265+
Ok(valid_types)
266+
}
267+
208268
/// Returns a Vec of all possible valid argument types for the given signature.
209269
fn get_valid_types(
210270
signature: &TypeSignature,

datafusion/expr/src/udwf.rs

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ use std::{
2727

2828
use arrow::datatypes::DataType;
2929

30-
use datafusion_common::Result;
30+
use datafusion_common::{not_impl_err, Result};
3131

3232
use crate::expr::WindowFunction;
3333
use crate::{
@@ -192,6 +192,11 @@ impl WindowUDF {
192192
pub fn sort_options(&self) -> Option<SortOptions> {
193193
self.inner.sort_options()
194194
}
195+
196+
/// See [`WindowUDFImpl::coerce_types`] for more details.
197+
pub fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
198+
self.inner.coerce_types(arg_types)
199+
}
195200
}
196201

197202
impl<F> From<F> for WindowUDF
@@ -353,6 +358,29 @@ pub trait WindowUDFImpl: Debug + Send + Sync {
353358
fn sort_options(&self) -> Option<SortOptions> {
354359
None
355360
}
361+
362+
/// Coerce arguments of a function call to types that the function can evaluate.
363+
///
364+
/// This function is only called if [`WindowUDFImpl::signature`] returns [`crate::TypeSignature::UserDefined`]. Most
365+
/// UDWFs should return one of the other variants of `TypeSignature` which handle common
366+
/// cases
367+
///
368+
/// See the [type coercion module](crate::type_coercion)
369+
/// documentation for more details on type coercion
370+
///
371+
/// For example, if your function requires a floating point arguments, but the user calls
372+
/// it like `my_func(1::int)` (aka with `1` as an integer), coerce_types could return `[DataType::Float64]`
373+
/// to ensure the argument was cast to `1::double`
374+
///
375+
/// # Parameters
376+
/// * `arg_types`: The argument types of the arguments this function with
377+
///
378+
/// # Return value
379+
/// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call
380+
/// arguments to these specific types.
381+
fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
382+
not_impl_err!("Function {} does not implement coerce_types", self.name())
383+
}
356384
}
357385

358386
/// WindowUDF that adds an alias to the underlying function. It is better to

datafusion/sqllogictest/test_files/window.slt

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4879,3 +4879,21 @@ SELECT lead(column2, 1.1) OVER (order by column1) FROM t;
48794879

48804880
query error DataFusion error: Execution error: Expected an integer value
48814881
SELECT nth_value(column2, 1.1) OVER (order by column1) FROM t;
4882+
4883+
statement ok
4884+
drop table t;
4885+
4886+
statement ok
4887+
create table t(a int, b int) as values (1, 2)
4888+
4889+
query II
4890+
select a, row_number() over (order by b) as rn from t;
4891+
----
4892+
1 1
4893+
4894+
# RowNumber expect 0 args.
4895+
query error
4896+
select a, row_number(a) over (order by b) as rn from t;
4897+
4898+
statement ok
4899+
drop table t;

0 commit comments

Comments
 (0)