Skip to content

Commit 6686e03

Browse files
authored
Use LogicalType for TypeSignature Numeric and String, Coercible (#13240)
* use logical type for signature Signed-off-by: jayzhan211 <[email protected]> * fmt & clippy Signed-off-by: jayzhan211 <[email protected]> * numeric Signed-off-by: jayzhan211 <[email protected]> * fix numeric Signed-off-by: jayzhan211 <[email protected]> * deprecate coercible Signed-off-by: jayzhan211 <[email protected]> * introduce numeric and numeric string Signed-off-by: jayzhan211 <[email protected]> * fix doc Signed-off-by: jayzhan211 <[email protected]> * cleanup Signed-off-by: jayzhan211 <[email protected]> * add back coercible Signed-off-by: jayzhan211 <[email protected]> * rename Signed-off-by: jayzhan211 <[email protected]> * fmt Signed-off-by: jayzhan211 <[email protected]> * rm numeric string signature Signed-off-by: jayzhan211 <[email protected]> * typo Signed-off-by: jayzhan211 <[email protected]> * improve doc and err msg Signed-off-by: jayzhan211 <[email protected]> --------- Signed-off-by: jayzhan211 <[email protected]>
1 parent 345117b commit 6686e03

File tree

21 files changed

+199
-117
lines changed

21 files changed

+199
-117
lines changed

datafusion/common/src/types/logical.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,12 @@ impl fmt::Debug for dyn LogicalType {
9898
}
9999
}
100100

101+
impl std::fmt::Display for dyn LogicalType {
102+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103+
write!(f, "{self:?}")
104+
}
105+
}
106+
101107
impl PartialEq for dyn LogicalType {
102108
fn eq(&self, other: &Self) -> bool {
103109
self.signature().eq(&other.signature())

datafusion/common/src/types/native.rs

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ use arrow::compute::can_cast_types;
2424
use arrow_schema::{
2525
DataType, Field, FieldRef, Fields, IntervalUnit, TimeUnit, UnionFields,
2626
};
27-
use std::sync::Arc;
27+
use std::{fmt::Display, sync::Arc};
2828

2929
/// Representation of a type that DataFusion can handle natively. It is a subset
3030
/// of the physical variants in Arrow's native [`DataType`].
@@ -183,6 +183,12 @@ pub enum NativeType {
183183
Map(LogicalFieldRef),
184184
}
185185

186+
impl Display for NativeType {
187+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
188+
write!(f, "NativeType::{self:?}")
189+
}
190+
}
191+
186192
impl LogicalType for NativeType {
187193
fn native(&self) -> &NativeType {
188194
self
@@ -348,6 +354,12 @@ impl LogicalType for NativeType {
348354
// mapping solutions to provide backwards compatibility while transitioning from
349355
// the purely physical system to a logical / physical system.
350356

357+
/// Converts a borrowed Arrow [`DataType`] into a [`NativeType`].
impl From<&DataType> for NativeType {
    fn from(value: &DataType) -> Self {
        // Delegate to the owned `From<DataType>` conversion; the clone is
        // needed because that conversion takes its argument by value.
        Self::from(value.clone())
    }
}
362+
351363
impl From<DataType> for NativeType {
352364
fn from(value: DataType) -> Self {
353365
use NativeType::*;
@@ -392,8 +404,33 @@ impl From<DataType> for NativeType {
392404
}
393405
}
394406

395-
impl From<&DataType> for NativeType {
396-
fn from(value: &DataType) -> Self {
397-
value.clone().into()
407+
impl NativeType {
408+
#[inline]
409+
pub fn is_numeric(&self) -> bool {
410+
use NativeType::*;
411+
matches!(
412+
self,
413+
UInt8
414+
| UInt16
415+
| UInt32
416+
| UInt64
417+
| Int8
418+
| Int16
419+
| Int32
420+
| Int64
421+
| Float16
422+
| Float32
423+
| Float64
424+
| Decimal(_, _)
425+
)
426+
}
427+
428+
#[inline]
429+
pub fn is_integer(&self) -> bool {
430+
use NativeType::*;
431+
matches!(
432+
self,
433+
UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64
434+
)
398435
}
399436
}

datafusion/expr-common/src/signature.rs

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
//! and return types of functions in DataFusion.
2020
2121
use arrow::datatypes::DataType;
22+
use datafusion_common::types::LogicalTypeRef;
2223

2324
/// Constant that is used as a placeholder for any valid timezone.
2425
/// This is used where a function can accept a timestamp type with any
@@ -106,10 +107,10 @@ pub enum TypeSignature {
106107
/// Exact number of arguments of an exact type
107108
Exact(Vec<DataType>),
108109
/// The number of arguments that can be coerced to in order
109-
/// For example, `Coercible(vec![DataType::Float64])` accepts
110+
/// For example, `Coercible(vec![logical_float64()])` accepts
110111
/// arguments like `vec![DataType::Int32]` or `vec![DataType::Float32]`
111112
/// since i32 and f32 can be cast to f64
112-
Coercible(Vec<DataType>),
113+
Coercible(Vec<LogicalTypeRef>),
113114
/// Fixed number of arguments of arbitrary types
114115
/// If a function takes 0 argument, its `TypeSignature` should be `Any(0)`
115116
Any(usize),
@@ -123,7 +124,9 @@ pub enum TypeSignature {
123124
/// Specifies Signatures for array functions
124125
ArraySignature(ArrayFunctionSignature),
125126
/// Fixed number of arguments of numeric types.
126-
/// See <https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html#method.is_numeric> to know which type is considered numeric
127+
/// See [`NativeType::is_numeric`] to know which type is considered numeric
128+
///
129+
/// [`NativeType::is_numeric`]: datafusion_common::types::NativeType::is_numeric
127130
Numeric(usize),
128131
/// Fixed number of arguments of all the same string types.
129132
/// The precedence of type from high to low is Utf8View, LargeUtf8 and Utf8.
@@ -201,7 +204,10 @@ impl TypeSignature {
201204
TypeSignature::Numeric(num) => {
202205
vec![format!("Numeric({num})")]
203206
}
204-
TypeSignature::Exact(types) | TypeSignature::Coercible(types) => {
207+
TypeSignature::Coercible(types) => {
208+
vec![Self::join_types(types, ", ")]
209+
}
210+
TypeSignature::Exact(types) => {
205211
vec![Self::join_types(types, ", ")]
206212
}
207213
TypeSignature::Any(arg_count) => {
@@ -322,7 +328,7 @@ impl Signature {
322328
}
323329
}
324330
/// Target coerce types in order
325-
pub fn coercible(target_types: Vec<DataType>, volatility: Volatility) -> Self {
331+
pub fn coercible(target_types: Vec<LogicalTypeRef>, volatility: Volatility) -> Self {
326332
Self {
327333
type_signature: TypeSignature::Coercible(target_types),
328334
volatility,

datafusion/expr/src/type_coercion/functions.rs

Lines changed: 94 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use arrow::{
2323
};
2424
use datafusion_common::{
2525
exec_err, internal_datafusion_err, internal_err, plan_err,
26+
types::{LogicalType, NativeType},
2627
utils::{coerced_fixed_size_list_to_list, list_ndims},
2728
Result,
2829
};
@@ -395,40 +396,56 @@ fn get_valid_types(
395396
}
396397
}
397398

399+
fn function_length_check(length: usize, expected_length: usize) -> Result<()> {
400+
if length < 1 {
401+
return plan_err!(
402+
"The signature expected at least one argument but received {expected_length}"
403+
);
404+
}
405+
406+
if length != expected_length {
407+
return plan_err!(
408+
"The signature expected {length} arguments but received {expected_length}"
409+
);
410+
}
411+
412+
Ok(())
413+
}
414+
398415
let valid_types = match signature {
399416
TypeSignature::Variadic(valid_types) => valid_types
400417
.iter()
401418
.map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect())
402419
.collect(),
403420
TypeSignature::String(number) => {
404-
if *number < 1 {
405-
return plan_err!(
406-
"The signature expected at least one argument but received {}",
407-
current_types.len()
408-
);
409-
}
410-
if *number != current_types.len() {
411-
return plan_err!(
412-
"The signature expected {} arguments but received {}",
413-
number,
414-
current_types.len()
415-
);
421+
function_length_check(current_types.len(), *number)?;
422+
423+
let mut new_types = Vec::with_capacity(current_types.len());
424+
for data_type in current_types.iter() {
425+
let logical_data_type: NativeType = data_type.into();
426+
if logical_data_type == NativeType::String {
427+
new_types.push(data_type.to_owned());
428+
} else if logical_data_type == NativeType::Null {
429+
// TODO: Switch to Utf8View if all the string functions support Utf8View
430+
new_types.push(DataType::Utf8);
431+
} else {
432+
return plan_err!(
433+
"The signature expected NativeType::String but received {logical_data_type}"
434+
);
435+
}
416436
}
417437

418-
fn coercion_rule(
438+
// Find the common string type for the given types
439+
fn find_common_type(
419440
lhs_type: &DataType,
420441
rhs_type: &DataType,
421442
) -> Result<DataType> {
422443
match (lhs_type, rhs_type) {
423-
(DataType::Null, DataType::Null) => Ok(DataType::Utf8),
424-
(DataType::Null, data_type) | (data_type, DataType::Null) => {
425-
coercion_rule(data_type, &DataType::Utf8)
426-
}
427444
(DataType::Dictionary(_, lhs), DataType::Dictionary(_, rhs)) => {
428-
coercion_rule(lhs, rhs)
445+
find_common_type(lhs, rhs)
429446
}
430447
(DataType::Dictionary(_, v), other)
431-
| (other, DataType::Dictionary(_, v)) => coercion_rule(v, other),
448+
| (other, DataType::Dictionary(_, v)) => find_common_type(v, other),
432449
_ => {
433450
if let Some(coerced_type) = string_coercion(lhs_type, rhs_type) {
434451
Ok(coerced_type)
@@ -444,15 +461,13 @@ fn get_valid_types(
444461
}
445462

446463
// Length checked above, safe to unwrap
447-
let mut coerced_type = current_types.first().unwrap().to_owned();
448-
for t in current_types.iter().skip(1) {
449-
coerced_type = coercion_rule(&coerced_type, t)?;
464+
let mut coerced_type = new_types.first().unwrap().to_owned();
465+
for t in new_types.iter().skip(1) {
466+
coerced_type = find_common_type(&coerced_type, t)?;
450467
}
451468

452469
fn base_type_or_default_type(data_type: &DataType) -> DataType {
453-
if data_type.is_null() {
454-
DataType::Utf8
455-
} else if let DataType::Dictionary(_, v) = data_type {
470+
if let DataType::Dictionary(_, v) = data_type {
456471
base_type_or_default_type(v)
457472
} else {
458473
data_type.to_owned()
@@ -462,22 +477,22 @@ fn get_valid_types(
462477
vec![vec![base_type_or_default_type(&coerced_type); *number]]
463478
}
464479
TypeSignature::Numeric(number) => {
465-
if *number < 1 {
466-
return plan_err!(
467-
"The signature expected at least one argument but received {}",
468-
current_types.len()
469-
);
470-
}
471-
if *number != current_types.len() {
472-
return plan_err!(
473-
"The signature expected {} arguments but received {}",
474-
number,
475-
current_types.len()
476-
);
477-
}
480+
function_length_check(current_types.len(), *number)?;
478481

479-
let mut valid_type = current_types.first().unwrap().clone();
482+
// Find common numeric type among the given types except string
483+
let mut valid_type = current_types.first().unwrap().to_owned();
480484
for t in current_types.iter().skip(1) {
485+
let logical_data_type: NativeType = t.into();
486+
if logical_data_type == NativeType::Null {
487+
continue;
488+
}
489+
490+
if !logical_data_type.is_numeric() {
491+
return plan_err!(
492+
"The signature expected NativeType::Numeric but received {logical_data_type}"
493+
);
494+
}
495+
481496
if let Some(coerced_type) = binary_numeric_coercion(&valid_type, t) {
482497
valid_type = coerced_type;
483498
} else {
@@ -489,31 +504,55 @@ fn get_valid_types(
489504
}
490505
}
491506

507+
let logical_data_type: NativeType = valid_type.clone().into();
508+
// Fall back to the default type if we don't know which type to coerce to
509+
// f64 is chosen since most of the math functions utilize Signature::numeric,
510+
// and their default type is double precision
511+
if logical_data_type == NativeType::Null {
512+
valid_type = DataType::Float64;
513+
}
514+
492515
vec![vec![valid_type; *number]]
493516
}
494517
TypeSignature::Coercible(target_types) => {
495-
if target_types.is_empty() {
496-
return plan_err!(
497-
"The signature expected at least one argument but received {}",
498-
current_types.len()
499-
);
500-
}
501-
if target_types.len() != current_types.len() {
502-
return plan_err!(
503-
"The signature expected {} arguments but received {}",
504-
target_types.len(),
505-
current_types.len()
506-
);
518+
function_length_check(current_types.len(), target_types.len())?;
519+
520+
// Aim to keep this logic as SIMPLE as possible!
521+
// Make sure the corresponding test is covered
522+
// If this function becomes COMPLEX, create another new signature!
523+
fn can_coerce_to(
524+
logical_type: &NativeType,
525+
target_type: &NativeType,
526+
) -> bool {
527+
if logical_type == target_type {
528+
return true;
529+
}
530+
531+
if logical_type == &NativeType::Null {
532+
return true;
533+
}
534+
535+
if target_type.is_integer() && logical_type.is_integer() {
536+
return true;
537+
}
538+
539+
false
507540
}
508541

509-
for (data_type, target_type) in current_types.iter().zip(target_types.iter())
542+
let mut new_types = Vec::with_capacity(current_types.len());
543+
for (current_type, target_type) in
544+
current_types.iter().zip(target_types.iter())
510545
{
511-
if !can_cast_types(data_type, target_type) {
512-
return plan_err!("{data_type} is not coercible to {target_type}");
546+
let logical_type: NativeType = current_type.into();
547+
let target_logical_type = target_type.native();
548+
if can_coerce_to(&logical_type, target_logical_type) {
549+
let target_type =
550+
target_logical_type.default_cast_for(current_type)?;
551+
new_types.push(target_type);
513552
}
514553
}
515554

516-
vec![target_types.to_owned()]
555+
vec![new_types]
517556
}
518557
TypeSignature::Uniform(number, valid_types) => valid_types
519558
.iter()

datafusion/functions-aggregate/src/first_last.rs

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL;
3333
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
3434
use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity};
3535
use datafusion_expr::{
36-
Accumulator, AggregateUDFImpl, ArrayFunctionSignature, Documentation, Expr,
37-
ExprFunctionExt, Signature, SortExpr, TypeSignature, Volatility,
36+
Accumulator, AggregateUDFImpl, Documentation, Expr, ExprFunctionExt, Signature,
37+
SortExpr, Volatility,
3838
};
3939
use datafusion_functions_aggregate_common::utils::get_sort_options;
4040
use datafusion_physical_expr_common::sort_expr::LexOrdering;
@@ -79,15 +79,7 @@ impl Default for FirstValue {
7979
impl FirstValue {
8080
pub fn new() -> Self {
8181
Self {
82-
signature: Signature::one_of(
83-
vec![
84-
// TODO: we can introduce more strict signature that only numeric of array types are allowed
85-
TypeSignature::ArraySignature(ArrayFunctionSignature::Array),
86-
TypeSignature::Numeric(1),
87-
TypeSignature::Uniform(1, vec![DataType::Utf8]),
88-
],
89-
Volatility::Immutable,
90-
),
82+
signature: Signature::any(1, Volatility::Immutable),
9183
requirement_satisfied: false,
9284
}
9385
}
@@ -406,15 +398,7 @@ impl Default for LastValue {
406398
impl LastValue {
407399
pub fn new() -> Self {
408400
Self {
409-
signature: Signature::one_of(
410-
vec![
411-
// TODO: we can introduce more strict signature that only numeric of array types are allowed
412-
TypeSignature::ArraySignature(ArrayFunctionSignature::Array),
413-
TypeSignature::Numeric(1),
414-
TypeSignature::Uniform(1, vec![DataType::Utf8]),
415-
],
416-
Volatility::Immutable,
417-
),
401+
signature: Signature::any(1, Volatility::Immutable),
418402
requirement_satisfied: false,
419403
}
420404
}

0 commit comments

Comments
 (0)