Skip to content

Commit 8aafa54

Browse files
authored
Apply type_union_resolution to array and values (#12753)
* cleanup make array coercion rule Signed-off-by: jayzhan211 <[email protected]> * change to type union resolution Signed-off-by: jayzhan211 <[email protected]> * change value too Signed-off-by: jayzhan211 <[email protected]> * fix tpyo Signed-off-by: jayzhan211 <[email protected]> --------- Signed-off-by: jayzhan211 <[email protected]>
1 parent cf76aba commit 8aafa54

File tree

9 files changed

+77
-98
lines changed

9 files changed

+77
-98
lines changed

datafusion/expr-common/src/type_coercion/binary.rs

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -471,10 +471,16 @@ fn type_union_resolution_coercion(
471471
let new_value_type = type_union_resolution_coercion(value_type, other_type);
472472
new_value_type.map(|t| DataType::Dictionary(index_type.clone(), Box::new(t)))
473473
}
474+
(DataType::List(lhs), DataType::List(rhs)) => {
475+
let new_item_type =
476+
type_union_resolution_coercion(lhs.data_type(), rhs.data_type());
477+
new_item_type.map(|t| DataType::List(Arc::new(Field::new("item", t, true))))
478+
}
474479
_ => {
475480
// numeric coercion is the same as comparison coercion, both find the narrowest type
476481
// that can accommodate both types
477482
binary_numeric_coercion(lhs_type, rhs_type)
483+
.or_else(|| temporal_coercion_nonstrict_timezone(lhs_type, rhs_type))
478484
.or_else(|| string_coercion(lhs_type, rhs_type))
479485
.or_else(|| numeric_string_coercion(lhs_type, rhs_type))
480486
}
@@ -507,22 +513,6 @@ pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<D
507513
.or_else(|| struct_coercion(lhs_type, rhs_type))
508514
}
509515

510-
/// Coerce `lhs_type` and `rhs_type` to a common type for `VALUES` expression
511-
///
512-
/// For example `VALUES (1, 2), (3.0, 4.0)` where the first row is `Int32` and
513-
/// the second row is `Float64` will coerce to `Float64`
514-
///
515-
pub fn values_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
516-
if lhs_type == rhs_type {
517-
// same type => equality is possible
518-
return Some(lhs_type.clone());
519-
}
520-
binary_numeric_coercion(lhs_type, rhs_type)
521-
.or_else(|| temporal_coercion_nonstrict_timezone(lhs_type, rhs_type))
522-
.or_else(|| string_coercion(lhs_type, rhs_type))
523-
.or_else(|| binary_coercion(lhs_type, rhs_type))
524-
}
525-
526516
/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation
527517
/// where one is numeric and one is `Utf8`/`LargeUtf8`.
528518
fn string_numeric_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {

datafusion/expr/src/logical_plan/builder.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ use crate::logical_plan::{
3535
Projection, Repartition, Sort, SubqueryAlias, TableScan, Union, Unnest, Values,
3636
Window,
3737
};
38-
use crate::type_coercion::binary::values_coercion;
3938
use crate::utils::{
4039
can_hash, columnize_expr, compare_sort_expr, expr_to_columns,
4140
find_valid_equijoin_key_pair, group_window_expr_by_sort_keys,
@@ -53,6 +52,7 @@ use datafusion_common::{
5352
plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue,
5453
TableReference, ToDFSchema, UnnestOptions,
5554
};
55+
use datafusion_expr_common::type_coercion::binary::type_union_resolution;
5656

5757
use super::dml::InsertOp;
5858
use super::plan::{ColumnUnnestList, ColumnUnnestType};
@@ -209,7 +209,8 @@ impl LogicalPlanBuilder {
209209
}
210210
if let Some(prev_type) = common_type {
211211
// get common type of each column values.
212-
let Some(new_type) = values_coercion(&data_type, &prev_type) else {
212+
let data_types = vec![prev_type.clone(), data_type.clone()];
213+
let Some(new_type) = type_union_resolution(&data_types) else {
213214
return plan_err!("Inconsistent data type across values list at row {i} column {j}. Was {prev_type} but found {data_type}");
214215
};
215216
common_type = Some(new_type);

datafusion/expr/src/type_coercion/functions.rs

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,20 @@ pub fn data_types(
167167
try_coerce_types(valid_types, current_types, &signature.type_signature)
168168
}
169169

170+
fn is_well_supported_signature(type_signature: &TypeSignature) -> bool {
171+
if let TypeSignature::OneOf(signatures) = type_signature {
172+
return signatures.iter().all(is_well_supported_signature);
173+
}
174+
175+
matches!(
176+
type_signature,
177+
TypeSignature::UserDefined
178+
| TypeSignature::Numeric(_)
179+
| TypeSignature::Coercible(_)
180+
| TypeSignature::Any(_)
181+
)
182+
}
183+
170184
fn try_coerce_types(
171185
valid_types: Vec<Vec<DataType>>,
172186
current_types: &[DataType],
@@ -175,14 +189,7 @@ fn try_coerce_types(
175189
let mut valid_types = valid_types;
176190

177191
// Well-supported signature that returns exact valid types.
178-
if !valid_types.is_empty()
179-
&& matches!(
180-
type_signature,
181-
TypeSignature::UserDefined
182-
| TypeSignature::Numeric(_)
183-
| TypeSignature::Coercible(_)
184-
)
185-
{
192+
if !valid_types.is_empty() && is_well_supported_signature(type_signature) {
186193
// exact valid types
187194
assert_eq!(valid_types.len(), 1);
188195
let valid_types = valid_types.swap_remove(0);

datafusion/functions-nested/src/make_array.rs

Lines changed: 23 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
//! [`ScalarUDFImpl`] definitions for `make_array` function.
1919
20+
use std::vec;
2021
use std::{any::Any, sync::Arc};
2122

2223
use arrow::array::{ArrayData, Capacities, MutableArrayData};
@@ -26,9 +27,8 @@ use arrow_array::{
2627
use arrow_buffer::OffsetBuffer;
2728
use arrow_schema::DataType::{LargeList, List, Null};
2829
use arrow_schema::{DataType, Field};
29-
use datafusion_common::internal_err;
3030
use datafusion_common::{plan_err, utils::array_into_list_array_nullable, Result};
31-
use datafusion_expr::type_coercion::binary::comparison_coercion;
31+
use datafusion_expr::binary::type_union_resolution;
3232
use datafusion_expr::TypeSignature;
3333
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
3434

@@ -82,19 +82,12 @@ impl ScalarUDFImpl for MakeArray {
8282
match arg_types.len() {
8383
0 => Ok(empty_array_type()),
8484
_ => {
85-
let mut expr_type = DataType::Null;
86-
for arg_type in arg_types {
87-
if !arg_type.equals_datatype(&DataType::Null) {
88-
expr_type = arg_type.clone();
89-
break;
90-
}
91-
}
92-
93-
if expr_type.is_null() {
94-
expr_type = DataType::Int64;
95-
}
96-
97-
Ok(List(Arc::new(Field::new("item", expr_type, true))))
85+
// At this point, all the type in array should be coerced to the same one
86+
Ok(List(Arc::new(Field::new(
87+
"item",
88+
arg_types[0].to_owned(),
89+
true,
90+
))))
9891
}
9992
}
10093
}
@@ -112,22 +105,21 @@ impl ScalarUDFImpl for MakeArray {
112105
}
113106

114107
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
115-
let new_type = arg_types.iter().skip(1).try_fold(
116-
arg_types.first().unwrap().clone(),
117-
|acc, x| {
118-
// The coerced types found by `comparison_coercion` are not guaranteed to be
119-
// coercible for the arguments. `comparison_coercion` returns more loose
120-
// types that can be coerced to both `acc` and `x` for comparison purpose.
121-
// See `maybe_data_types` for the actual coercion.
122-
let coerced_type = comparison_coercion(&acc, x);
123-
if let Some(coerced_type) = coerced_type {
124-
Ok(coerced_type)
125-
} else {
126-
internal_err!("Coercion from {acc:?} to {x:?} failed.")
127-
}
128-
},
129-
)?;
130-
Ok(vec![new_type; arg_types.len()])
108+
if let Some(new_type) = type_union_resolution(arg_types) {
109+
if let DataType::FixedSizeList(field, _) = new_type {
110+
Ok(vec![DataType::List(field); arg_types.len()])
111+
} else if new_type.is_null() {
112+
Ok(vec![DataType::Int64; arg_types.len()])
113+
} else {
114+
Ok(vec![new_type; arg_types.len()])
115+
}
116+
} else {
117+
plan_err!(
118+
"Fail to find the valid type between {:?} for {}",
119+
arg_types,
120+
self.name()
121+
)
122+
}
131123
}
132124
}
133125

datafusion/optimizer/src/analyzer/type_coercion.rs

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,6 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> {
456456
self.schema,
457457
&func,
458458
)?;
459-
let new_expr = coerce_arguments_for_fun(new_expr, self.schema, &func)?;
460459
Ok(Transformed::yes(Expr::ScalarFunction(
461460
ScalarFunction::new_udf(func, new_expr),
462461
)))
@@ -756,30 +755,6 @@ fn coerce_arguments_for_signature_with_aggregate_udf(
756755
.collect()
757756
}
758757

759-
fn coerce_arguments_for_fun(
760-
expressions: Vec<Expr>,
761-
schema: &DFSchema,
762-
fun: &Arc<ScalarUDF>,
763-
) -> Result<Vec<Expr>> {
764-
// Cast Fixedsizelist to List for array functions
765-
if fun.name() == "make_array" {
766-
expressions
767-
.into_iter()
768-
.map(|expr| {
769-
let data_type = expr.get_type(schema).unwrap();
770-
if let DataType::FixedSizeList(field, _) = data_type {
771-
let to_type = DataType::List(Arc::clone(&field));
772-
expr.cast_to(&to_type, schema)
773-
} else {
774-
Ok(expr)
775-
}
776-
})
777-
.collect()
778-
} else {
779-
Ok(expressions)
780-
}
781-
}
782-
783758
fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result<Case> {
784759
// Given expressions like:
785760
//

datafusion/sqllogictest/test_files/array.slt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6595,7 +6595,7 @@ select make_array(1, 2.0, null, 3)
65956595
query ?
65966596
select make_array(1.0, '2', null)
65976597
----
6598-
[1.0, 2, ]
6598+
[1.0, 2.0, ]
65996599

66006600
### FixedSizeListArray
66016601

datafusion/sqllogictest/test_files/errors.slt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,5 +128,5 @@ from aggregate_test_100
128128
order by c9
129129

130130

131-
statement error Inconsistent data type across values list at row 1 column 0. Was Int64 but found Utf8
131+
query error DataFusion error: Arrow error: Cast error: Cannot cast string 'foo' to value of Int64 type
132132
create table foo as values (1), ('foo');

datafusion/sqllogictest/test_files/map.slt

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -148,18 +148,17 @@ SELECT MAKE_MAP([1,2], ['a', 'b'], [3,4], ['b']);
148148
{[1, 2]: [a, b], [3, 4]: [b]}
149149

150150
query ?
151-
SELECT MAKE_MAP('POST', 41, 'HEAD', 'ab', 'PATCH', 30);
151+
SELECT MAKE_MAP('POST', 41, 'HEAD', 53, 'PATCH', 30);
152152
----
153-
{POST: 41, HEAD: ab, PATCH: 30}
153+
{POST: 41, HEAD: 53, PATCH: 30}
154+
155+
query error DataFusion error: Arrow error: Cast error: Cannot cast string 'ab' to value of Int64 type
156+
SELECT MAKE_MAP('POST', 41, 'HEAD', 'ab', 'PATCH', 30);
154157

158+
# Map keys can not be NULL
155159
query error
156160
SELECT MAKE_MAP('POST', 41, 'HEAD', 33, null, 30);
157161

158-
query ?
159-
SELECT MAKE_MAP('POST', 41, 'HEAD', 'ab', 'PATCH', 30);
160-
----
161-
{POST: 41, HEAD: ab, PATCH: 30}
162-
163162
query ?
164163
SELECT MAKE_MAP()
165164
----
@@ -517,9 +516,12 @@ query error
517516
SELECT MAP {'a': MAP {1:'a', 2:'b', 3:'c'}, 'b': MAP {2:'c', 4:'d'} }[NULL];
518517

519518
query ?
520-
SELECT MAP { 'a': 1, 2: 3 };
519+
SELECT MAP { 'a': 1, 'b': 3 };
521520
----
522-
{a: 1, 2: 3}
521+
{a: 1, b: 3}
522+
523+
query error DataFusion error: Arrow error: Cast error: Cannot cast string 'a' to value of Int64 type
524+
SELECT MAP { 'a': 1, 2: 3 };
523525

524526
# TODO(https://github.com/apache/datafusion/issues/11785): fix accessing map with non-string key
525527
# query ?
@@ -610,9 +612,12 @@ select map_extract(column1, 1), map_extract(column1, 5), map_extract(column1, 7)
610612
# Tests for map_keys
611613

612614
query ?
613-
SELECT map_keys(MAP { 'a': 1, 2: 3 });
615+
SELECT map_keys(MAP { 'a': 1, 'b': 3 });
614616
----
615-
[a, 2]
617+
[a, b]
618+
619+
query error DataFusion error: Arrow error: Cast error: Cannot cast string 'a' to value of Int64 type
620+
SELECT map_keys(MAP { 'a': 1, 2: 3 });
616621

617622
query ?
618623
SELECT map_keys(MAP {'a':1, 'b':2, 'c':3 }) FROM t;
@@ -657,8 +662,11 @@ SELECT map_keys(column1) from map_array_table_1;
657662

658663
# Tests for map_values
659664

660-
query ?
665+
query error DataFusion error: Arrow error: Cast error: Cannot cast string 'a' to value of Int64 type
661666
SELECT map_values(MAP { 'a': 1, 2: 3 });
667+
668+
query ?
669+
SELECT map_values(MAP { 'a': 1, 'b': 3 });
662670
----
663671
[1, 3]
664672

datafusion/sqllogictest/test_files/select.slt

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -348,17 +348,23 @@ VALUES (1),()
348348
statement error DataFusion error: Error during planning: Inconsistent data length across values list: got 2 values in row 1 but expected 1
349349
VALUES (1),(1,2)
350350

351-
statement error DataFusion error: Error during planning: Inconsistent data type across values list at row 1 column 0
351+
query I
352352
VALUES (1),('2')
353+
----
354+
1
355+
2
353356

354357
query R
355358
VALUES (1),(2.0)
356359
----
357360
1
358361
2
359362

360-
statement error DataFusion error: Error during planning: Inconsistent data type across values list at row 1 column 1
363+
query II
361364
VALUES (1,2), (1,'2')
365+
----
366+
1 2
367+
1 2
362368

363369
query IT
364370
VALUES (1,'a'),(NULL,'b'),(3,'c')

0 commit comments

Comments
 (0)