|
15 | 15 | // specific language governing permissions and limitations
|
16 | 16 | // under the License.
|
17 | 17 |
|
18 |
| -use std::sync::Arc; |
19 |
| - |
20 |
| -use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature}; |
| 18 | +use super::binary::{binary_numeric_coercion, comparison_coercion}; |
| 19 | +use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF}; |
21 | 20 | use arrow::{
|
22 | 21 | compute::can_cast_types,
|
23 | 22 | datatypes::{DataType, TimeUnit},
|
24 | 23 | };
|
25 |
| -use datafusion_common::utils::{coerced_fixed_size_list_to_list, list_ndims}; |
26 | 24 | use datafusion_common::{
|
27 |
| - exec_err, internal_datafusion_err, internal_err, plan_err, Result, |
| 25 | + exec_err, internal_datafusion_err, internal_err, plan_err, |
| 26 | + utils::{coerced_fixed_size_list_to_list, list_ndims}, |
| 27 | + Result, |
28 | 28 | };
|
29 | 29 | use datafusion_expr_common::signature::{
|
30 | 30 | ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD,
|
31 | 31 | };
|
32 |
| - |
33 |
| -use super::binary::{binary_numeric_coercion, comparison_coercion}; |
| 32 | +use std::sync::Arc; |
34 | 33 |
|
35 | 34 | /// Performs type coercion for scalar function arguments.
|
36 | 35 | ///
|
@@ -66,6 +65,13 @@ pub fn data_types_with_scalar_udf(
|
66 | 65 | try_coerce_types(valid_types, current_types, &signature.type_signature)
|
67 | 66 | }
|
68 | 67 |
|
| 68 | +/// Performs type coercion for aggregate function arguments. |
| 69 | +/// |
| 70 | +/// Returns the data types to which each argument must be coerced to |
| 71 | +/// match `signature`. |
| 72 | +/// |
| 73 | +/// For more details on coercion in general, please see the |
| 74 | +/// [`type_coercion`](crate::type_coercion) module. |
69 | 75 | pub fn data_types_with_aggregate_udf(
|
70 | 76 | current_types: &[DataType],
|
71 | 77 | func: &AggregateUDF,
|
@@ -95,6 +101,39 @@ pub fn data_types_with_aggregate_udf(
|
95 | 101 | try_coerce_types(valid_types, current_types, &signature.type_signature)
|
96 | 102 | }
|
97 | 103 |
|
| 104 | +/// Performs type coercion for window function arguments. |
| 105 | +/// |
| 106 | +/// Returns the data types to which each argument must be coerced to |
| 107 | +/// match `signature`. |
| 108 | +/// |
| 109 | +/// For more details on coercion in general, please see the |
| 110 | +/// [`type_coercion`](crate::type_coercion) module. |
| 111 | +pub fn data_types_with_window_udf( |
| 112 | + current_types: &[DataType], |
| 113 | + func: &WindowUDF, |
| 114 | +) -> Result<Vec<DataType>> { |
| 115 | + let signature = func.signature(); |
| 116 | + |
| 117 | + if current_types.is_empty() { |
| 118 | + if signature.type_signature.supports_zero_argument() { |
| 119 | + return Ok(vec![]); |
| 120 | + } else { |
| 121 | + return plan_err!("{} does not support zero arguments.", func.name()); |
| 122 | + } |
| 123 | + } |
| 124 | + |
| 125 | + let valid_types = |
| 126 | + get_valid_types_with_window_udf(&signature.type_signature, current_types, func)?; |
| 127 | + if valid_types |
| 128 | + .iter() |
| 129 | + .any(|data_type| data_type == current_types) |
| 130 | + { |
| 131 | + return Ok(current_types.to_vec()); |
| 132 | + } |
| 133 | + |
| 134 | + try_coerce_types(valid_types, current_types, &signature.type_signature) |
| 135 | +} |
| 136 | + |
98 | 137 | /// Performs type coercion for function arguments.
|
99 | 138 | ///
|
100 | 139 | /// Returns the data types to which each argument must be coerced to
|
@@ -205,6 +244,27 @@ fn get_valid_types_with_aggregate_udf(
|
205 | 244 | Ok(valid_types)
|
206 | 245 | }
|
207 | 246 |
|
| 247 | +fn get_valid_types_with_window_udf( |
| 248 | + signature: &TypeSignature, |
| 249 | + current_types: &[DataType], |
| 250 | + func: &WindowUDF, |
| 251 | +) -> Result<Vec<Vec<DataType>>> { |
| 252 | + let valid_types = match signature { |
| 253 | + TypeSignature::UserDefined => match func.coerce_types(current_types) { |
| 254 | + Ok(coerced_types) => vec![coerced_types], |
| 255 | + Err(e) => return exec_err!("User-defined coercion failed with {:?}", e), |
| 256 | + }, |
| 257 | + TypeSignature::OneOf(signatures) => signatures |
| 258 | + .iter() |
| 259 | + .filter_map(|t| get_valid_types_with_window_udf(t, current_types, func).ok()) |
| 260 | + .flatten() |
| 261 | + .collect::<Vec<_>>(), |
| 262 | + _ => get_valid_types(signature, current_types)?, |
| 263 | + }; |
| 264 | + |
| 265 | + Ok(valid_types) |
| 266 | +} |
| 267 | + |
208 | 268 | /// Returns a Vec of all possible valid argument types for the given signature.
|
209 | 269 | fn get_valid_types(
|
210 | 270 | signature: &TypeSignature,
|
|
0 commit comments