Skip to content

Commit edbdefe

Browse files
authored
Support array_concat for Utf8View (#14378)
* Add tests for concatenating different string types * clean up code * fmt
1 parent 67bc04c commit edbdefe

File tree

3 files changed

+69
-70
lines changed

3 files changed

+69
-70
lines changed

datafusion/expr-common/src/type_coercion/binary.rs

Lines changed: 2 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ use arrow::datatypes::{
3030
};
3131
use datafusion_common::types::NativeType;
3232
use datafusion_common::{
33-
exec_datafusion_err, exec_err, internal_err, plan_datafusion_err, plan_err,
34-
Diagnostic, Result, Span, Spans,
33+
exec_err, internal_err, plan_datafusion_err, plan_err, Diagnostic, Result, Span,
34+
Spans,
3535
};
3636
use itertools::Itertools;
3737

@@ -928,54 +928,6 @@ fn get_wider_decimal_type(
928928
}
929929
}
930930

931-
/// Returns the wider type among arguments `lhs` and `rhs`.
932-
/// The wider type is the type that can safely represent values from both types
933-
/// without information loss. Returns an Error if types are incompatible.
934-
pub fn get_wider_type(lhs: &DataType, rhs: &DataType) -> Result<DataType> {
935-
use arrow::datatypes::DataType::*;
936-
Ok(match (lhs, rhs) {
937-
(lhs, rhs) if lhs == rhs => lhs.clone(),
938-
// Right UInt is larger than left UInt.
939-
(UInt8, UInt16 | UInt32 | UInt64) | (UInt16, UInt32 | UInt64) | (UInt32, UInt64) |
940-
// Right Int is larger than left Int.
941-
(Int8, Int16 | Int32 | Int64) | (Int16, Int32 | Int64) | (Int32, Int64) |
942-
// Right Float is larger than left Float.
943-
(Float16, Float32 | Float64) | (Float32, Float64) |
944-
// Right String is larger than left String.
945-
(Utf8, LargeUtf8) |
946-
// Any right type is wider than a left hand side Null.
947-
(Null, _) => rhs.clone(),
948-
// Left UInt is larger than right UInt.
949-
(UInt16 | UInt32 | UInt64, UInt8) | (UInt32 | UInt64, UInt16) | (UInt64, UInt32) |
950-
// Left Int is larger than right Int.
951-
(Int16 | Int32 | Int64, Int8) | (Int32 | Int64, Int16) | (Int64, Int32) |
952-
// Left Float is larger than right Float.
953-
(Float32 | Float64, Float16) | (Float64, Float32) |
954-
// Left String is larger than right String.
955-
(LargeUtf8, Utf8) |
956-
// Any left type is wider than a right hand side Null.
957-
(_, Null) => lhs.clone(),
958-
(List(lhs_field), List(rhs_field)) => {
959-
let field_type =
960-
get_wider_type(lhs_field.data_type(), rhs_field.data_type())?;
961-
if lhs_field.name() != rhs_field.name() {
962-
return Err(exec_datafusion_err!(
963-
"There is no wider type that can represent both {lhs} and {rhs}."
964-
));
965-
}
966-
assert_eq!(lhs_field.name(), rhs_field.name());
967-
let field_name = lhs_field.name();
968-
let nullable = lhs_field.is_nullable() | rhs_field.is_nullable();
969-
List(Arc::new(Field::new(field_name, field_type, nullable)))
970-
}
971-
(_, _) => {
972-
return Err(exec_datafusion_err!(
973-
"There is no wider type that can represent both {lhs} and {rhs}."
974-
));
975-
}
976-
})
977-
}
978-
979931
/// Convert the numeric data type to the decimal data type.
980932
/// We support signed and unsigned integer types and floating-point type.
981933
fn coerce_numeric_type_to_decimal(numeric_type: &DataType) -> Option<DataType> {

datafusion/functions-nested/src/concat.rs

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,7 @@ use datafusion_common::{
2929
cast::as_generic_list_array, exec_err, not_impl_err, plan_err, utils::list_ndims,
3030
};
3131
use datafusion_expr::{
32-
type_coercion::binary::get_wider_type, ColumnarValue, Documentation, ScalarUDFImpl,
33-
Signature, Volatility,
32+
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
3433
};
3534
use datafusion_macros::user_doc;
3635

@@ -276,25 +275,32 @@ impl ScalarUDFImpl for ArrayConcat {
276275
let mut expr_type = DataType::Null;
277276
let mut max_dims = 0;
278277
for arg_type in arg_types {
279-
match arg_type {
280-
DataType::List(field) => {
281-
if !field.data_type().equals_datatype(&DataType::Null) {
282-
let dims = list_ndims(arg_type);
283-
expr_type = match max_dims.cmp(&dims) {
284-
Ordering::Greater => expr_type,
285-
Ordering::Equal => get_wider_type(&expr_type, arg_type)?,
286-
Ordering::Less => {
287-
max_dims = dims;
288-
arg_type.clone()
289-
}
290-
};
278+
let DataType::List(field) = arg_type else {
279+
return plan_err!(
280+
"The array_concat function can only accept list as the args."
281+
);
282+
};
283+
if !field.data_type().equals_datatype(&DataType::Null) {
284+
let dims = list_ndims(arg_type);
285+
expr_type = match max_dims.cmp(&dims) {
286+
Ordering::Greater => expr_type,
287+
Ordering::Equal => {
288+
if expr_type == DataType::Null {
289+
arg_type.clone()
290+
} else if !expr_type.equals_datatype(arg_type) {
291+
return plan_err!(
292+
"It is not possible to concatenate arrays of different types. Expected: {}, got: {}", expr_type, arg_type
293+
);
294+
} else {
295+
expr_type
296+
}
291297
}
292-
}
293-
_ => {
294-
return plan_err!(
295-
"The array_concat function can only accept list as the args."
296-
)
297-
}
298+
299+
Ordering::Less => {
300+
max_dims = dims;
301+
arg_type.clone()
302+
}
303+
};
298304
}
299305
}
300306

datafusion/sqllogictest/test_files/array.slt

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2870,6 +2870,47 @@ select array_concat([]);
28702870
----
28712871
[]
28722872

2873+
# Concatenating string arrays
2874+
query ?
2875+
select array_concat(
2876+
['1', '2'],
2877+
['3']
2878+
);
2879+
----
2880+
[1, 2, 3]
2881+
2882+
# Concatenating string arrays
2883+
query ?
2884+
select array_concat(
2885+
[arrow_cast('1', 'LargeUtf8'), arrow_cast('2', 'LargeUtf8')],
2886+
[arrow_cast('3', 'LargeUtf8')]
2887+
);
2888+
----
2889+
[1, 2, 3]
2890+
2891+
# Concatenating Utf8View (string view) arrays
2892+
query ?
2893+
select array_concat(
2894+
[arrow_cast('1', 'Utf8View'), arrow_cast('2', 'Utf8View')],
2895+
[arrow_cast('3', 'Utf8View')]
2896+
);
2897+
----
2898+
[1, 2, 3]
2899+
2900+
# Concatenating mixed string types (Utf8 with LargeUtf8) is rejected with a planning error
2901+
query error DataFusion error: Error during planning: It is not possible to concatenate arrays of different types\. Expected: List\(Field \{ name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\), got: List\(Field \{ name: "item", data_type: LargeUtf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\)
2902+
select array_concat(
2903+
[arrow_cast('1', 'Utf8'), arrow_cast('2', 'Utf8')],
2904+
[arrow_cast('3', 'LargeUtf8')]
2905+
);
2906+
2907+
# Concatenating mixed string types (Utf8 with Utf8View) is rejected with a planning error
2908+
query error DataFusion error: Error during planning: It is not possible to concatenate arrays of different types\. Expected: List\(Field \{ name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\), got: List\(Field \{ name: "item", data_type: Utf8View, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\)
2909+
select array_concat(
2910+
[arrow_cast('1', 'Utf8'), arrow_cast('2', 'Utf8')],
2911+
[arrow_cast('3', 'Utf8View')]
2912+
);
2913+
28732914
# array_concat error
28742915
query error DataFusion error: Error during planning: The array_concat function can only accept list as the args\.
28752916
select array_concat(1, 2);

0 commit comments

Comments
 (0)