Skip to content

Commit 7e0fc14

Browse files
authored
Fix get_type for higher-order array functions (#13756)
* Fix get_type for higher-order array functions
* Fix recursive flatten. The fix is covered by the recursive flatten test case in array.slt
* Restore "keep LargeList" in Array signature
* Clarify naming in the test
1 parent 5500b11 commit 7e0fc14

File tree

4 files changed

+116
-3
lines changed

4 files changed

+116
-3
lines changed

datafusion/expr-common/src/signature.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,9 @@ pub enum ArrayFunctionSignature {
204204
/// The function takes a single argument that must be a List/LargeList/FixedSizeList
205205
/// or something that can be coerced to one of those types.
206206
Array,
207+
/// The function takes a single argument that must be a List/LargeList/FixedSizeList
208+
/// which gets coerced to List; if the element type is itself list-like, it is recursively coerced to List too.
209+
RecursiveArray,
207210
/// Specialized Signature for MapArray
208211
/// The function takes a single argument that must be a MapArray
209212
MapArray,
@@ -227,6 +230,9 @@ impl Display for ArrayFunctionSignature {
227230
ArrayFunctionSignature::Array => {
228231
write!(f, "array")
229232
}
233+
ArrayFunctionSignature::RecursiveArray => {
234+
write!(f, "recursive_array")
235+
}
230236
ArrayFunctionSignature::MapArray => {
231237
write!(f, "map_array")
232238
}

datafusion/expr/src/type_coercion/functions.rs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,11 @@ use arrow::{
2121
compute::can_cast_types,
2222
datatypes::{DataType, TimeUnit},
2323
};
24+
use datafusion_common::utils::coerced_fixed_size_list_to_list;
2425
use datafusion_common::{
2526
exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err,
2627
types::{LogicalType, NativeType},
27-
utils::{coerced_fixed_size_list_to_list, list_ndims},
28+
utils::list_ndims,
2829
Result,
2930
};
3031
use datafusion_expr_common::{
@@ -418,7 +419,16 @@ fn get_valid_types(
418419
_ => Ok(vec![vec![]]),
419420
}
420421
}
422+
421423
/// Maps a candidate argument type to the list type accepted for it, without
/// descending into the element type.
///
/// * `List` and `LargeList` are returned unchanged (as a clone of the input).
/// * `FixedSizeList` becomes a `List` that shares the same element field; the
///   fixed length is dropped.
/// * Any other type yields `None`.
fn array(array_type: &DataType) -> Option<DataType> {
424+
match array_type {
425+
DataType::List(_) | DataType::LargeList(_) => Some(array_type.clone()),
// LargeList is kept as-is above; only FixedSizeList is rewritten to List.
426+
DataType::FixedSizeList(field, _) => Some(DataType::List(Arc::clone(field))),
427+
_ => None,
428+
}
429+
}
430+
431+
fn recursive_array(array_type: &DataType) -> Option<DataType> {
422432
match array_type {
423433
DataType::List(_)
424434
| DataType::LargeList(_)
@@ -687,6 +697,13 @@ fn get_valid_types(
687697
array(&current_types[0])
688698
.map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]])
689699
}
700+
ArrayFunctionSignature::RecursiveArray => {
701+
if current_types.len() != 1 {
702+
return Ok(vec![vec![]]);
703+
}
704+
recursive_array(&current_types[0])
705+
.map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]])
706+
}
690707
ArrayFunctionSignature::MapArray => {
691708
if current_types.len() != 1 {
692709
return Ok(vec![vec![]]);

datafusion/functions-nested/src/extract.rs

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -993,3 +993,86 @@ where
993993
let data = mutable.freeze();
994994
Ok(arrow::array::make_array(data))
995995
}
996+
997+
// Unit tests checking the return type reported by the `array_element` UDF.
#[cfg(test)]
998+
mod tests {
999+
use super::array_element_udf;
1000+
use arrow_schema::{DataType, Field};
1001+
use datafusion_common::{Column, DFSchema, ScalarValue};
1002+
use datafusion_expr::expr::ScalarFunction;
1003+
use datafusion_expr::{cast, Expr, ExprSchemable};
1004+
use std::collections::HashMap;
1005+
1006+
// Regression test for https://github.com/apache/datafusion/issues/13755
// The input is a List whose items are FixedSizeList(Int32, 13). Every way of
// asking for the return type below (return_type, return_type_from_exprs with
// cast-null and with column arguments, and ExprSchemable::get_type) must
// report the FixedSizeList item type itself.
1007+
#[test]
1008+
fn test_array_element_return_type_fixed_size_list() {
1009+
let fixed_size_list_type = DataType::FixedSizeList(
1010+
Field::new("some_arbitrary_test_field", DataType::Int32, false).into(),
1011+
13,
1012+
);
1013+
let array_type = DataType::List(
1014+
Field::new_list_field(fixed_size_list_type.clone(), true).into(),
1015+
);
1016+
let index_type = DataType::Int64;
1017+
1018+
let schema = DFSchema::from_unqualified_fields(
1019+
vec![
1020+
Field::new("my_array", array_type.clone(), false),
1021+
Field::new("my_index", index_type.clone(), false),
1022+
]
1023+
.into(),
1024+
HashMap::default(),
1025+
)
1026+
.unwrap();
1027+
1028+
let udf = array_element_udf();
1029+
1030+
// ScalarUDFImpl::return_type
1031+
assert_eq!(
1032+
udf.return_type(&[array_type.clone(), index_type.clone()])
1033+
.unwrap(),
1034+
fixed_size_list_type
1035+
);
1036+
1037+
// ScalarUDFImpl::return_type_from_exprs with typed exprs
1038+
assert_eq!(
1039+
udf.return_type_from_exprs(
1040+
&[
1041+
cast(Expr::Literal(ScalarValue::Null), array_type.clone()),
1042+
cast(Expr::Literal(ScalarValue::Null), index_type.clone()),
1043+
],
1044+
&schema,
1045+
&[array_type.clone(), index_type.clone()]
1046+
)
1047+
.unwrap(),
1048+
fixed_size_list_type
1049+
);
1050+
1051+
// ScalarUDFImpl::return_type_from_exprs with exprs not carrying type
1052+
assert_eq!(
1053+
udf.return_type_from_exprs(
1054+
&[
1055+
Expr::Column(Column::new_unqualified("my_array")),
1056+
Expr::Column(Column::new_unqualified("my_index")),
1057+
],
1058+
&schema,
1059+
&[array_type.clone(), index_type.clone()]
1060+
)
1061+
.unwrap(),
1062+
fixed_size_list_type
1063+
);
1064+
1065+
// Via ExprSchemable::get_type (e.g. SimplifyInfo)
1066+
let udf_expr = Expr::ScalarFunction(ScalarFunction {
1067+
func: array_element_udf(),
1068+
args: vec![
1069+
Expr::Column(Column::new_unqualified("my_array")),
1070+
Expr::Column(Column::new_unqualified("my_index")),
1071+
],
1072+
});
1073+
assert_eq!(
1074+
ExprSchemable::get_type(&udf_expr, &schema).unwrap(),
1075+
fixed_size_list_type
1076+
);
1077+
}
1078+
}

datafusion/functions-nested/src/flatten.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ use datafusion_common::cast::{
2828
use datafusion_common::{exec_err, Result};
2929
use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY;
3030
use datafusion_expr::{
31-
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
31+
ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature,
32+
TypeSignature, Volatility,
3233
};
3334
use std::any::Any;
3435
use std::sync::{Arc, OnceLock};
@@ -56,7 +57,13 @@ impl Default for Flatten {
5657
impl Flatten {
5758
pub fn new() -> Self {
5859
Self {
59-
signature: Signature::array(Volatility::Immutable),
60+
signature: Signature {
61+
// TODO (https://github.com/apache/datafusion/issues/13757) flatten should be single-step, not recursive
62+
type_signature: TypeSignature::ArraySignature(
63+
ArrayFunctionSignature::RecursiveArray,
64+
),
65+
volatility: Volatility::Immutable,
66+
},
6067
aliases: vec![],
6168
}
6269
}

0 commit comments

Comments
 (0)