@@ -23,6 +23,7 @@ use arrow::{
23
23
} ;
24
24
use datafusion_common:: {
25
25
exec_err, internal_datafusion_err, internal_err, plan_err,
26
+ types:: { LogicalType , NativeType } ,
26
27
utils:: { coerced_fixed_size_list_to_list, list_ndims} ,
27
28
Result ,
28
29
} ;
@@ -395,40 +396,56 @@ fn get_valid_types(
395
396
}
396
397
}
397
398
399
+ fn function_length_check ( length : usize , expected_length : usize ) -> Result < ( ) > {
400
+ if length < 1 {
401
+ return plan_err ! (
402
+ "The signature expected at least one argument but received {expected_length}"
403
+ ) ;
404
+ }
405
+
406
+ if length != expected_length {
407
+ return plan_err ! (
408
+ "The signature expected {length} arguments but received {expected_length}"
409
+ ) ;
410
+ }
411
+
412
+ Ok ( ( ) )
413
+ }
414
+
398
415
let valid_types = match signature {
399
416
TypeSignature :: Variadic ( valid_types) => valid_types
400
417
. iter ( )
401
418
. map ( |valid_type| current_types. iter ( ) . map ( |_| valid_type. clone ( ) ) . collect ( ) )
402
419
. collect ( ) ,
403
420
TypeSignature :: String ( number) => {
404
- if * number < 1 {
405
- return plan_err ! (
406
- "The signature expected at least one argument but received {}" ,
407
- current_types. len( )
408
- ) ;
409
- }
410
- if * number != current_types. len ( ) {
411
- return plan_err ! (
412
- "The signature expected {} arguments but received {}" ,
413
- number,
414
- current_types. len( )
415
- ) ;
421
+ function_length_check ( current_types. len ( ) , * number) ?;
422
+
423
+ let mut new_types = Vec :: with_capacity ( current_types. len ( ) ) ;
424
+ for data_type in current_types. iter ( ) {
425
+ let logical_data_type: NativeType = data_type. into ( ) ;
426
+ if logical_data_type == NativeType :: String {
427
+ new_types. push ( data_type. to_owned ( ) ) ;
428
+ } else if logical_data_type == NativeType :: Null {
429
+ // TODO: Switch to Utf8View if all the string functions supports Utf8View
430
+ new_types. push ( DataType :: Utf8 ) ;
431
+ } else {
432
+ return plan_err ! (
433
+ "The signature expected NativeType::String but received {logical_data_type}"
434
+ ) ;
435
+ }
416
436
}
417
437
418
- fn coercion_rule (
438
+ // Find the common string type for the given types
439
+ fn find_common_type (
419
440
lhs_type : & DataType ,
420
441
rhs_type : & DataType ,
421
442
) -> Result < DataType > {
422
443
match ( lhs_type, rhs_type) {
423
- ( DataType :: Null , DataType :: Null ) => Ok ( DataType :: Utf8 ) ,
424
- ( DataType :: Null , data_type) | ( data_type, DataType :: Null ) => {
425
- coercion_rule ( data_type, & DataType :: Utf8 )
426
- }
427
444
( DataType :: Dictionary ( _, lhs) , DataType :: Dictionary ( _, rhs) ) => {
428
- coercion_rule ( lhs, rhs)
445
+ find_common_type ( lhs, rhs)
429
446
}
430
447
( DataType :: Dictionary ( _, v) , other)
431
- | ( other, DataType :: Dictionary ( _, v) ) => coercion_rule ( v, other) ,
448
+ | ( other, DataType :: Dictionary ( _, v) ) => find_common_type ( v, other) ,
432
449
_ => {
433
450
if let Some ( coerced_type) = string_coercion ( lhs_type, rhs_type) {
434
451
Ok ( coerced_type)
@@ -444,15 +461,13 @@ fn get_valid_types(
444
461
}
445
462
446
463
// Length checked above, safe to unwrap
447
- let mut coerced_type = current_types . first ( ) . unwrap ( ) . to_owned ( ) ;
448
- for t in current_types . iter ( ) . skip ( 1 ) {
449
- coerced_type = coercion_rule ( & coerced_type, t) ?;
464
+ let mut coerced_type = new_types . first ( ) . unwrap ( ) . to_owned ( ) ;
465
+ for t in new_types . iter ( ) . skip ( 1 ) {
466
+ coerced_type = find_common_type ( & coerced_type, t) ?;
450
467
}
451
468
452
469
fn base_type_or_default_type ( data_type : & DataType ) -> DataType {
453
- if data_type. is_null ( ) {
454
- DataType :: Utf8
455
- } else if let DataType :: Dictionary ( _, v) = data_type {
470
+ if let DataType :: Dictionary ( _, v) = data_type {
456
471
base_type_or_default_type ( v)
457
472
} else {
458
473
data_type. to_owned ( )
@@ -462,22 +477,22 @@ fn get_valid_types(
462
477
vec ! [ vec![ base_type_or_default_type( & coerced_type) ; * number] ]
463
478
}
464
479
TypeSignature :: Numeric ( number) => {
465
- if * number < 1 {
466
- return plan_err ! (
467
- "The signature expected at least one argument but received {}" ,
468
- current_types. len( )
469
- ) ;
470
- }
471
- if * number != current_types. len ( ) {
472
- return plan_err ! (
473
- "The signature expected {} arguments but received {}" ,
474
- number,
475
- current_types. len( )
476
- ) ;
477
- }
480
+ function_length_check ( current_types. len ( ) , * number) ?;
478
481
479
- let mut valid_type = current_types. first ( ) . unwrap ( ) . clone ( ) ;
482
+ // Find common numeric type amongs given types except string
483
+ let mut valid_type = current_types. first ( ) . unwrap ( ) . to_owned ( ) ;
480
484
for t in current_types. iter ( ) . skip ( 1 ) {
485
+ let logical_data_type: NativeType = t. into ( ) ;
486
+ if logical_data_type == NativeType :: Null {
487
+ continue ;
488
+ }
489
+
490
+ if !logical_data_type. is_numeric ( ) {
491
+ return plan_err ! (
492
+ "The signature expected NativeType::Numeric but received {logical_data_type}"
493
+ ) ;
494
+ }
495
+
481
496
if let Some ( coerced_type) = binary_numeric_coercion ( & valid_type, t) {
482
497
valid_type = coerced_type;
483
498
} else {
@@ -489,31 +504,55 @@ fn get_valid_types(
489
504
}
490
505
}
491
506
507
+ let logical_data_type: NativeType = valid_type. clone ( ) . into ( ) ;
508
+ // Fallback to default type if we don't know which type to coerced to
509
+ // f64 is chosen since most of the math functions utilize Signature::numeric,
510
+ // and their default type is double precision
511
+ if logical_data_type == NativeType :: Null {
512
+ valid_type = DataType :: Float64 ;
513
+ }
514
+
492
515
vec ! [ vec![ valid_type; * number] ]
493
516
}
494
517
TypeSignature :: Coercible ( target_types) => {
495
- if target_types. is_empty ( ) {
496
- return plan_err ! (
497
- "The signature expected at least one argument but received {}" ,
498
- current_types. len( )
499
- ) ;
500
- }
501
- if target_types. len ( ) != current_types. len ( ) {
502
- return plan_err ! (
503
- "The signature expected {} arguments but received {}" ,
504
- target_types. len( ) ,
505
- current_types. len( )
506
- ) ;
518
+ function_length_check ( current_types. len ( ) , target_types. len ( ) ) ?;
519
+
520
+ // Aim to keep this logic as SIMPLE as possible!
521
+ // Make sure the corresponding test is covered
522
+ // If this function becomes COMPLEX, create another new signature!
523
+ fn can_coerce_to (
524
+ logical_type : & NativeType ,
525
+ target_type : & NativeType ,
526
+ ) -> bool {
527
+ if logical_type == target_type {
528
+ return true ;
529
+ }
530
+
531
+ if logical_type == & NativeType :: Null {
532
+ return true ;
533
+ }
534
+
535
+ if target_type. is_integer ( ) && logical_type. is_integer ( ) {
536
+ return true ;
537
+ }
538
+
539
+ false
507
540
}
508
541
509
- for ( data_type, target_type) in current_types. iter ( ) . zip ( target_types. iter ( ) )
542
+ let mut new_types = Vec :: with_capacity ( current_types. len ( ) ) ;
543
+ for ( current_type, target_type) in
544
+ current_types. iter ( ) . zip ( target_types. iter ( ) )
510
545
{
511
- if !can_cast_types ( data_type, target_type) {
512
- return plan_err ! ( "{data_type} is not coercible to {target_type}" ) ;
546
+ let logical_type: NativeType = current_type. into ( ) ;
547
+ let target_logical_type = target_type. native ( ) ;
548
+ if can_coerce_to ( & logical_type, target_logical_type) {
549
+ let target_type =
550
+ target_logical_type. default_cast_for ( current_type) ?;
551
+ new_types. push ( target_type) ;
513
552
}
514
553
}
515
554
516
- vec ! [ target_types . to_owned ( ) ]
555
+ vec ! [ new_types ]
517
556
}
518
557
TypeSignature :: Uniform ( number, valid_types) => valid_types
519
558
. iter ( )
0 commit comments