Skip to content

Commit 8506795

Browse files
committed
fix numeric_coercion
Signed-off-by: jayzhan211 <[email protected]>
1 parent 06bbe12 commit 8506795

File tree

2 files changed

+90
-46
lines changed

2 files changed

+90
-46
lines changed

datafusion/expr/src/type_coercion/binary.rs

Lines changed: 78 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,8 @@ pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<D
290290
// same type => equality is possible
291291
return Some(lhs_type.clone());
292292
}
293-
comparison_binary_numeric_coercion(lhs_type, rhs_type)
293+
294+
numeric_coercion(lhs_type, rhs_type)
294295
.or_else(|| dictionary_coercion(lhs_type, rhs_type, true))
295296
.or_else(|| temporal_coercion(lhs_type, rhs_type))
296297
.or_else(|| string_coercion(lhs_type, rhs_type))
@@ -354,72 +355,103 @@ fn string_temporal_coercion(
354355
}
355356

356357
/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation
357-
/// where one both are numeric
358-
fn comparison_binary_numeric_coercion(
359-
lhs_type: &DataType,
360-
rhs_type: &DataType,
361-
) -> Option<DataType> {
358+
/// where both are numeric and the coerced type MAY not be the same as either input type.
359+
pub fn numeric_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
362360
use arrow::datatypes::DataType::*;
363361
if !lhs_type.is_numeric() || !rhs_type.is_numeric() {
364362
return None;
365363
};
366364

367-
// same type => all good
368-
if lhs_type == rhs_type {
369-
return Some(lhs_type.clone());
370-
}
371-
372365
// these are ordered from most informative to least informative so
373366
// that the coercion does not lose information via truncation
374367
match (lhs_type, rhs_type) {
375-
// Prefer decimal data type over floating point for comparison operation
376368
(Decimal128(_, _), Decimal128(_, _)) => {
377369
get_wider_decimal_type(lhs_type, rhs_type)
378370
}
379371
(Decimal128(_, _), _) => get_comparison_common_decimal_type(lhs_type, rhs_type),
380-
(_, Decimal128(_, _)) => get_comparison_common_decimal_type(rhs_type, lhs_type),
381372
(Decimal256(_, _), Decimal256(_, _)) => {
382373
get_wider_decimal_type(lhs_type, rhs_type)
383374
}
384375
(Decimal256(_, _), _) => get_comparison_common_decimal_type(lhs_type, rhs_type),
385-
(_, Decimal256(_, _)) => get_comparison_common_decimal_type(rhs_type, lhs_type),
376+
377+
// f64
378+
// Prefer f64 over u64 and i64, data lossy is expected
386379
(Float64, _) | (_, Float64) => Some(Float64),
387-
(_, Float32) | (Float32, _) => Some(Float32),
388-
// The following match arms encode the following logic: Given the two
389-
// integral types, we choose the narrowest possible integral type that
390-
// accommodates all values of both types. Note that some information
391-
// loss is inevitable when we have a signed type and a `UInt64`, in
392-
// which case we use `Int64`;i.e. the widest signed integral type.
393-
(Int64, _)
394-
| (_, Int64)
395-
| (UInt64, Int8)
396-
| (Int8, UInt64)
397-
| (UInt64, Int16)
398-
| (Int16, UInt64)
399-
| (UInt64, Int32)
400-
| (Int32, UInt64)
401-
| (UInt32, Int8)
402-
| (Int8, UInt32)
403-
| (UInt32, Int16)
404-
| (Int16, UInt32)
405-
| (UInt32, Int32)
406-
| (Int32, UInt32) => Some(Int64),
407-
(UInt64, _) | (_, UInt64) => Some(UInt64),
408-
(Int32, _)
409-
| (_, Int32)
410-
| (UInt16, Int16)
411-
| (Int16, UInt16)
412-
| (UInt16, Int8)
413-
| (Int8, UInt16) => Some(Int32),
414-
(UInt32, _) | (_, UInt32) => Some(UInt32),
415-
(Int16, _) | (_, Int16) | (Int8, UInt8) | (UInt8, Int8) => Some(Int16),
416-
(UInt16, _) | (_, UInt16) => Some(UInt16),
417-
(Int8, _) | (_, Int8) => Some(Int8),
418-
(UInt8, _) | (_, UInt8) => Some(UInt8),
380+
381+
// u64
382+
// Prefer f64 over u64, data lossy is expected
383+
(UInt64, Float32) | (Float32, UInt64) | (UInt64, Float16) | (Float16, UInt64) => {
384+
Some(Float64)
385+
}
386+
// Prefer i64 over u64, data lossy is expected
387+
(UInt64, data_type) | (data_type, UInt64) => {
388+
if data_type.is_signed_integer() {
389+
Some(Int64)
390+
} else {
391+
Some(UInt64)
392+
}
393+
}
394+
395+
// i64
396+
// Prefer f64 over i64, data lossy is expected
397+
(Int64, Float32) | (Float32, Int64) | (Int64, Float16) | (Float16, Int64) => {
398+
Some(Float64)
399+
}
400+
(Int64, _) | (_, Int64) => Some(Int64),
401+
402+
// f32
403+
(Float32, _) | (_, Float32) => Some(Float32),
404+
405+
// u32
406+
(UInt32, Float16) | (Float16, UInt32) => Some(Float64),
407+
(UInt32, data_type) | (data_type, UInt32) => {
408+
if data_type.is_signed_integer() {
409+
Some(Int64)
410+
} else {
411+
Some(UInt32)
412+
}
413+
}
414+
415+
// i32
416+
// f32 is not guaranteed to be able to represent all i32 values
417+
(Int32, Float16) | (Float16, Int32) => Some(Float64),
418+
(Int32, _) | (_, Int32) => Some(Int32),
419+
420+
// f16
421+
(Float16, UInt16) | (UInt16, Float16) | (Float16, Int16) | (Int16, Float16) => {
422+
Some(Float32)
423+
}
424+
(Float16, _) | (_, Float16) => Some(Float16),
425+
426+
// u16
427+
(UInt16, data_type) | (data_type, UInt16) => {
428+
if data_type.is_signed_integer() {
429+
Some(Int32)
430+
} else {
431+
Some(UInt16)
432+
}
433+
}
434+
435+
// i16
436+
(Int16, _) | (_, Int16) => Some(Int16),
437+
438+
// u8
439+
(UInt8, UInt8) => Some(UInt8),
440+
(UInt8, Int8) | (Int8, UInt8) => Some(Int16),
441+
442+
// i8
443+
(Int8, Int8) => Some(Int8),
444+
419445
_ => None,
420446
}
421447
}
422448

449+
/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation
450+
/// where both are numeric and the coerced type SHOULD be one of the input types.
451+
pub fn exact_numeric_coercion(_: &DataType, _: &DataType) -> Option<DataType> {
452+
todo!("exact_numeric_coercion")
453+
}
454+
423455
/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of
424456
/// a comparison operation where one is a decimal
425457
fn get_comparison_common_decimal_type(

datafusion/sqllogictest/test_files/array.slt

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2975,6 +2975,18 @@ select make_array(1.0, '2', null)
29752975
----
29762976
[1.0, 2, ]
29772977

2978+
# make_array scalar function #5
2979+
query error
2980+
select
2981+
make_array(arrow_cast(1, 'Int8'), arrow_cast(18446744073709551610, 'UInt64')),
2982+
arrow_typeof(make_array(arrow_cast(1, 'Int8'), arrow_cast(18446744073709551610, 'UInt64')))
2983+
;
2984+
----
2985+
DataFusion error: Optimizer rule 'simplify_expressions' failed
2986+
caused by
2987+
Arrow error: Cast error: Can't cast value 18446744073709551610 to type Int64
2988+
2989+
29782990
### FixedSizeListArray
29792991

29802992
statement ok

0 commit comments

Comments
 (0)