@@ -290,7 +290,8 @@ pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<D
290
290
// same type => equality is possible
291
291
return Some ( lhs_type. clone ( ) ) ;
292
292
}
293
- comparison_binary_numeric_coercion ( lhs_type, rhs_type)
293
+
294
+ numeric_coercion ( lhs_type, rhs_type)
294
295
. or_else ( || dictionary_coercion ( lhs_type, rhs_type, true ) )
295
296
. or_else ( || temporal_coercion ( lhs_type, rhs_type) )
296
297
. or_else ( || string_coercion ( lhs_type, rhs_type) )
@@ -354,72 +355,103 @@ fn string_temporal_coercion(
354
355
}
355
356
356
357
/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation
357
- /// where one both are numeric
358
- fn comparison_binary_numeric_coercion (
359
- lhs_type : & DataType ,
360
- rhs_type : & DataType ,
361
- ) -> Option < DataType > {
358
+ /// where both are numeric and the coerced type MAY not be the same as either input type.
359
+ pub fn numeric_coercion ( lhs_type : & DataType , rhs_type : & DataType ) -> Option < DataType > {
362
360
use arrow:: datatypes:: DataType :: * ;
363
361
if !lhs_type. is_numeric ( ) || !rhs_type. is_numeric ( ) {
364
362
return None ;
365
363
} ;
366
364
367
- // same type => all good
368
- if lhs_type == rhs_type {
369
- return Some ( lhs_type. clone ( ) ) ;
370
- }
371
-
372
365
// these are ordered from most informative to least informative so
373
366
// that the coercion does not lose information via truncation
374
367
match ( lhs_type, rhs_type) {
375
- // Prefer decimal data type over floating point for comparison operation
376
368
( Decimal128 ( _, _) , Decimal128 ( _, _) ) => {
377
369
get_wider_decimal_type ( lhs_type, rhs_type)
378
370
}
379
371
( Decimal128 ( _, _) , _) => get_comparison_common_decimal_type ( lhs_type, rhs_type) ,
380
- ( _, Decimal128 ( _, _) ) => get_comparison_common_decimal_type ( rhs_type, lhs_type) ,
381
372
( Decimal256 ( _, _) , Decimal256 ( _, _) ) => {
382
373
get_wider_decimal_type ( lhs_type, rhs_type)
383
374
}
384
375
( Decimal256 ( _, _) , _) => get_comparison_common_decimal_type ( lhs_type, rhs_type) ,
385
- ( _, Decimal256 ( _, _) ) => get_comparison_common_decimal_type ( rhs_type, lhs_type) ,
376
+
377
+ // f64
378
+ // Prefer f64 over u64 and i64, data lossy is expected
386
379
( Float64 , _) | ( _, Float64 ) => Some ( Float64 ) ,
387
- ( _, Float32 ) | ( Float32 , _) => Some ( Float32 ) ,
388
- // The following match arms encode the following logic: Given the two
389
- // integral types, we choose the narrowest possible integral type that
390
- // accommodates all values of both types. Note that some information
391
- // loss is inevitable when we have a signed type and a `UInt64`, in
392
- // which case we use `Int64`;i.e. the widest signed integral type.
393
- ( Int64 , _)
394
- | ( _, Int64 )
395
- | ( UInt64 , Int8 )
396
- | ( Int8 , UInt64 )
397
- | ( UInt64 , Int16 )
398
- | ( Int16 , UInt64 )
399
- | ( UInt64 , Int32 )
400
- | ( Int32 , UInt64 )
401
- | ( UInt32 , Int8 )
402
- | ( Int8 , UInt32 )
403
- | ( UInt32 , Int16 )
404
- | ( Int16 , UInt32 )
405
- | ( UInt32 , Int32 )
406
- | ( Int32 , UInt32 ) => Some ( Int64 ) ,
407
- ( UInt64 , _) | ( _, UInt64 ) => Some ( UInt64 ) ,
408
- ( Int32 , _)
409
- | ( _, Int32 )
410
- | ( UInt16 , Int16 )
411
- | ( Int16 , UInt16 )
412
- | ( UInt16 , Int8 )
413
- | ( Int8 , UInt16 ) => Some ( Int32 ) ,
414
- ( UInt32 , _) | ( _, UInt32 ) => Some ( UInt32 ) ,
415
- ( Int16 , _) | ( _, Int16 ) | ( Int8 , UInt8 ) | ( UInt8 , Int8 ) => Some ( Int16 ) ,
416
- ( UInt16 , _) | ( _, UInt16 ) => Some ( UInt16 ) ,
417
- ( Int8 , _) | ( _, Int8 ) => Some ( Int8 ) ,
418
- ( UInt8 , _) | ( _, UInt8 ) => Some ( UInt8 ) ,
380
+
381
+ // u64
382
+ // Prefer f64 over u64, data lossy is expected
383
+ ( UInt64 , Float32 ) | ( Float32 , UInt64 ) | ( UInt64 , Float16 ) | ( Float16 , UInt64 ) => {
384
+ Some ( Float64 )
385
+ }
386
+ // Prefer i64 over u64, data lossy is expected
387
+ ( UInt64 , data_type) | ( data_type, UInt64 ) => {
388
+ if data_type. is_signed_integer ( ) {
389
+ Some ( Int64 )
390
+ } else {
391
+ Some ( UInt64 )
392
+ }
393
+ }
394
+
395
+ // i64
396
+ // Prefer f64 over i64, data lossy is expected
397
+ ( Int64 , Float32 ) | ( Float32 , Int64 ) | ( Int64 , Float16 ) | ( Float16 , Int64 ) => {
398
+ Some ( Float64 )
399
+ }
400
+ ( Int64 , _) | ( _, Int64 ) => Some ( Int64 ) ,
401
+
402
+ // f32
403
+ ( Float32 , _) | ( _, Float32 ) => Some ( Float32 ) ,
404
+
405
+ // u32
406
+ ( UInt32 , Float16 ) | ( Float16 , UInt32 ) => Some ( Float64 ) ,
407
+ ( UInt32 , data_type) | ( data_type, UInt32 ) => {
408
+ if data_type. is_signed_integer ( ) {
409
+ Some ( Int64 )
410
+ } else {
411
+ Some ( UInt32 )
412
+ }
413
+ }
414
+
415
+ // i32
416
+ // f32 is not guaranteed to be able to represent all i32 values
417
+ ( Int32 , Float16 ) | ( Float16 , Int32 ) => Some ( Float64 ) ,
418
+ ( Int32 , _) | ( _, Int32 ) => Some ( Int32 ) ,
419
+
420
+ // f16
421
+ ( Float16 , UInt16 ) | ( UInt16 , Float16 ) | ( Float16 , Int16 ) | ( Int16 , Float16 ) => {
422
+ Some ( Float32 )
423
+ }
424
+ ( Float16 , _) | ( _, Float16 ) => Some ( Float16 ) ,
425
+
426
+ // u16
427
+ ( UInt16 , data_type) | ( data_type, UInt16 ) => {
428
+ if data_type. is_signed_integer ( ) {
429
+ Some ( Int32 )
430
+ } else {
431
+ Some ( UInt16 )
432
+ }
433
+ }
434
+
435
+ // i16
436
+ ( Int16 , _) | ( _, Int16 ) => Some ( Int16 ) ,
437
+
438
+ // u8
439
+ ( UInt8 , UInt8 ) => Some ( UInt8 ) ,
440
+ ( UInt8 , Int8 ) | ( Int8 , UInt8 ) => Some ( Int16 ) ,
441
+
442
+ // i8
443
+ ( Int8 , Int8 ) => Some ( Int8 ) ,
444
+
419
445
_ => None ,
420
446
}
421
447
}
422
448
449
+ /// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation
450
+ /// where both are numeric and the coerced type SHOULD be one of the input types.
451
+ pub fn exact_numeric_coercion ( _: & DataType , _: & DataType ) -> Option < DataType > {
452
+ todo ! ( "exact_numeric_coercion" )
453
+ }
454
+
423
455
/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of
424
456
/// a comparison operation where one is a decimal
425
457
fn get_comparison_common_decimal_type (
0 commit comments