Skip to content

Commit 24f9e49

Browse files
committed
refactor: use TypeSignature::Coercible for math functions
1 parent 680383b commit 24f9e49

File tree

19 files changed

+309
-77
lines changed

19 files changed

+309
-77
lines changed

datafusion/common/src/types/native.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,12 @@ impl NativeType {
436436
)
437437
}
438438

439+
#[inline]
440+
pub fn is_float(&self) -> bool {
441+
use NativeType::*;
442+
matches!(self, Float16 | Float32 | Float64)
443+
}
444+
439445
#[inline]
440446
pub fn is_integer(&self) -> bool {
441447
use NativeType::*;

datafusion/expr-common/src/signature.rs

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,8 +215,8 @@ pub enum TypeSignatureClass {
215215
Interval,
216216
Duration,
217217
Native(LogicalTypeRef),
218-
// TODO:
219-
// Numeric
218+
Numeric,
219+
Float,
220220
Integer,
221221
}
222222

@@ -252,6 +252,16 @@ impl TypeSignatureClass {
252252
TypeSignatureClass::Duration => {
253253
vec![DataType::Duration(TimeUnit::Nanosecond)]
254254
}
255+
TypeSignatureClass::Numeric => {
256+
vec![
257+
DataType::Int64,
258+
DataType::Float64,
259+
DataType::Decimal256(3, -2),
260+
]
261+
}
262+
TypeSignatureClass::Float => {
263+
vec![DataType::Float64]
264+
}
255265
TypeSignatureClass::Integer => {
256266
vec![DataType::Int64]
257267
}
@@ -269,6 +279,8 @@ impl TypeSignatureClass {
269279
TypeSignatureClass::Time if logical_type.is_time() => true,
270280
TypeSignatureClass::Interval if logical_type.is_interval() => true,
271281
TypeSignatureClass::Duration if logical_type.is_duration() => true,
282+
TypeSignatureClass::Numeric if logical_type.is_numeric() => true,
283+
TypeSignatureClass::Float if logical_type.is_float() => true,
272284
TypeSignatureClass::Integer if logical_type.is_integer() => true,
273285
_ => false,
274286
}
@@ -297,6 +309,12 @@ impl TypeSignatureClass {
297309
TypeSignatureClass::Duration if native_type.is_duration() => {
298310
Ok(origin_type.to_owned())
299311
}
312+
TypeSignatureClass::Numeric if native_type.is_numeric() => {
313+
Ok(origin_type.to_owned())
314+
}
315+
TypeSignatureClass::Float if native_type.is_float() => {
316+
Ok(origin_type.to_owned())
317+
}
300318
TypeSignatureClass::Integer if native_type.is_integer() => {
301319
Ok(origin_type.to_owned())
302320
}

datafusion/functions/src/macros.rs

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -160,13 +160,16 @@ macro_rules! make_math_unary_udf {
160160

161161
use arrow::array::{ArrayRef, AsArray};
162162
use arrow::datatypes::{DataType, Float32Type, Float64Type};
163+
use datafusion_common::types::logical_null;
164+
use datafusion_common::types::NativeType;
163165
use datafusion_common::{exec_err, Result};
164166
use datafusion_expr::interval_arithmetic::Interval;
165167
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
166168
use datafusion_expr::{
167169
ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl,
168-
Signature, Volatility,
170+
Signature, TypeSignatureClass, Volatility,
169171
};
172+
use datafusion_expr_common::signature::Coercion;
170173

171174
#[derive(Debug)]
172175
pub struct $UDF {
@@ -175,11 +178,16 @@ macro_rules! make_math_unary_udf {
175178

176179
impl $UDF {
177180
pub fn new() -> Self {
178-
use DataType::*;
179181
Self {
180-
signature: Signature::uniform(
181-
1,
182-
vec![Float64, Float32],
182+
signature: Signature::coercible(
183+
vec![Coercion::new_implicit(
184+
TypeSignatureClass::Float,
185+
vec![
186+
TypeSignatureClass::Integer,
187+
TypeSignatureClass::Native(logical_null()),
188+
],
189+
NativeType::Float64,
190+
)],
183191
Volatility::Immutable,
184192
),
185193
}

datafusion/functions/src/math/abs.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,17 @@ use arrow::array::{
2626
};
2727
use arrow::datatypes::DataType;
2828
use arrow::error::ArrowError;
29+
use datafusion_common::types::{logical_null, NativeType};
2930
use datafusion_common::{
3031
internal_datafusion_err, not_impl_err, utils::take_function_args, Result,
3132
};
3233
use datafusion_expr::interval_arithmetic::Interval;
3334
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
3435
use datafusion_expr::{
3536
ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
36-
Volatility,
37+
TypeSignature, TypeSignatureClass, Volatility,
3738
};
39+
use datafusion_expr_common::signature::Coercion;
3840
use datafusion_macros::user_doc;
3941

4042
type MathArrayFunction = fn(&ArrayRef) -> Result<ArrayRef>;
@@ -126,7 +128,14 @@ impl Default for AbsFunc {
126128
impl AbsFunc {
127129
pub fn new() -> Self {
128130
Self {
129-
signature: Signature::numeric(1, Volatility::Immutable),
131+
signature: Signature::new(
132+
TypeSignature::Coercible(vec![Coercion::new_implicit(
133+
TypeSignatureClass::Numeric,
134+
vec![TypeSignatureClass::Native(logical_null())],
135+
NativeType::Float64,
136+
)]),
137+
Volatility::Immutable,
138+
),
130139
}
131140
}
132141
}

datafusion/functions/src/math/cot.rs

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,14 @@ use std::sync::Arc;
2121
use arrow::array::{ArrayRef, AsArray};
2222
use arrow::datatypes::DataType::{Float32, Float64};
2323
use arrow::datatypes::{DataType, Float32Type, Float64Type};
24+
use datafusion_common::types::NativeType;
25+
use datafusion_expr_common::signature::Coercion;
2426

2527
use crate::utils::make_scalar_function;
2628
use datafusion_common::{exec_err, Result};
27-
use datafusion_expr::{ColumnarValue, Documentation, ScalarFunctionArgs};
29+
use datafusion_expr::{
30+
ColumnarValue, Documentation, ScalarFunctionArgs, TypeSignature, TypeSignatureClass,
31+
};
2832
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
2933
use datafusion_macros::user_doc;
3034

@@ -47,16 +51,18 @@ impl Default for CotFunc {
4751

4852
impl CotFunc {
4953
pub fn new() -> Self {
50-
use DataType::*;
5154
Self {
5255
// math expressions expect 1 argument of type f64 or f32
5356
// priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we
5457
// return the best approximation for it (in f64).
5558
// We accept f32 because in this case it is clear that the best approximation
5659
// will be as good as the number of digits in the number
57-
signature: Signature::uniform(
58-
1,
59-
vec![Float64, Float32],
60+
signature: Signature::new(
61+
TypeSignature::Coercible(vec![Coercion::new_implicit(
62+
TypeSignatureClass::Float,
63+
vec![TypeSignatureClass::Integer],
64+
NativeType::Float64,
65+
)]),
6066
Volatility::Immutable,
6167
),
6268
}

datafusion/functions/src/math/factorial.rs

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use arrow::{
1919
array::{ArrayRef, Int64Array},
2020
error::ArrowError,
2121
};
22+
use datafusion_expr_common::signature::Coercion;
2223
use std::any::Any;
2324
use std::sync::Arc;
2425

@@ -27,11 +28,13 @@ use arrow::datatypes::DataType::Int64;
2728

2829
use crate::utils::make_scalar_function;
2930
use datafusion_common::{
30-
arrow_datafusion_err, exec_err, internal_datafusion_err, DataFusionError, Result,
31+
arrow_datafusion_err, exec_err, internal_datafusion_err,
32+
types::{logical_int64, logical_null, NativeType},
33+
DataFusionError, Result,
3134
};
3235
use datafusion_expr::{
3336
ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
34-
Volatility,
37+
TypeSignature, TypeSignatureClass, Volatility,
3538
};
3639
use datafusion_macros::user_doc;
3740

@@ -55,7 +58,17 @@ impl Default for FactorialFunc {
5558
impl FactorialFunc {
5659
pub fn new() -> Self {
5760
Self {
58-
signature: Signature::uniform(1, vec![Int64], Volatility::Immutable),
61+
signature: Signature::new(
62+
TypeSignature::Coercible(vec![Coercion::new_implicit(
63+
TypeSignatureClass::Native(logical_int64()),
64+
vec![
65+
TypeSignatureClass::Integer,
66+
TypeSignatureClass::Native(logical_null()),
67+
],
68+
NativeType::Int64,
69+
)]),
70+
Volatility::Immutable,
71+
),
5972
}
6073
}
6174
}

datafusion/functions/src/math/gcd.rs

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,17 @@ use arrow::array::{new_null_array, ArrayRef, AsArray, Int64Array, PrimitiveArray
1919
use arrow::compute::try_binary;
2020
use arrow::datatypes::{DataType, Int64Type};
2121
use arrow::error::ArrowError;
22+
use datafusion_common::types::{logical_int64, logical_null, NativeType};
2223
use std::any::Any;
2324
use std::mem::swap;
2425
use std::sync::Arc;
2526

2627
use datafusion_common::{exec_err, internal_datafusion_err, Result, ScalarValue};
2728
use datafusion_expr::{
2829
ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
29-
Volatility,
30+
TypeSignatureClass, Volatility,
3031
};
32+
use datafusion_expr_common::signature::Coercion;
3133
use datafusion_macros::user_doc;
3234

3335
#[user_doc(
@@ -51,9 +53,25 @@ impl Default for GcdFunc {
5153
impl GcdFunc {
5254
pub fn new() -> Self {
5355
Self {
54-
signature: Signature::uniform(
55-
2,
56-
vec![DataType::Int64],
56+
signature: Signature::coercible(
57+
vec![
58+
Coercion::new_implicit(
59+
TypeSignatureClass::Native(logical_int64()),
60+
vec![
61+
TypeSignatureClass::Integer,
62+
TypeSignatureClass::Native(logical_null()),
63+
],
64+
NativeType::Int64,
65+
),
66+
Coercion::new_implicit(
67+
TypeSignatureClass::Native(logical_int64()),
68+
vec![
69+
TypeSignatureClass::Integer,
70+
TypeSignatureClass::Native(logical_null()),
71+
],
72+
NativeType::Int64,
73+
),
74+
],
5775
Volatility::Immutable,
5876
),
5977
}

datafusion/functions/src/math/iszero.rs

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,14 @@ use arrow::array::{ArrayRef, AsArray, BooleanArray};
2222
use arrow::datatypes::DataType::{Boolean, Float32, Float64};
2323
use arrow::datatypes::{DataType, Float32Type, Float64Type};
2424

25+
use datafusion_common::types::{logical_null, NativeType};
2526
use datafusion_common::{exec_err, Result};
26-
use datafusion_expr::TypeSignature::Exact;
27+
use datafusion_expr::TypeSignature::{self};
2728
use datafusion_expr::{
2829
ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
29-
Volatility,
30+
TypeSignatureClass, Volatility,
3031
};
32+
use datafusion_expr_common::signature::Coercion;
3133
use datafusion_macros::user_doc;
3234

3335
use crate::utils::make_scalar_function;
@@ -51,10 +53,16 @@ impl Default for IsZeroFunc {
5153

5254
impl IsZeroFunc {
5355
pub fn new() -> Self {
54-
use DataType::*;
5556
Self {
56-
signature: Signature::one_of(
57-
vec![Exact(vec![Float32]), Exact(vec![Float64])],
57+
signature: Signature::new(
58+
TypeSignature::Coercible(vec![Coercion::new_implicit(
59+
TypeSignatureClass::Float,
60+
vec![
61+
TypeSignatureClass::Integer,
62+
TypeSignatureClass::Native(logical_null()),
63+
],
64+
NativeType::Float64,
65+
)]),
5866
Volatility::Immutable,
5967
),
6068
}

datafusion/functions/src/math/lcm.rs

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,15 @@ use arrow::datatypes::DataType;
2323
use arrow::datatypes::DataType::Int64;
2424

2525
use arrow::error::ArrowError;
26+
use datafusion_common::types::{logical_int64, logical_null, NativeType};
2627
use datafusion_common::{
2728
arrow_datafusion_err, exec_err, internal_datafusion_err, DataFusionError, Result,
2829
};
2930
use datafusion_expr::{
3031
ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
31-
Volatility,
32+
TypeSignature, TypeSignatureClass, Volatility,
3233
};
34+
use datafusion_expr_common::signature::Coercion;
3335
use datafusion_macros::user_doc;
3436

3537
use super::gcd::unsigned_gcd;
@@ -55,9 +57,28 @@ impl Default for LcmFunc {
5557

5658
impl LcmFunc {
5759
pub fn new() -> Self {
58-
use DataType::*;
5960
Self {
60-
signature: Signature::uniform(2, vec![Int64], Volatility::Immutable),
61+
signature: Signature::new(
62+
TypeSignature::Coercible(vec![
63+
Coercion::new_implicit(
64+
TypeSignatureClass::Native(logical_int64()),
65+
vec![
66+
TypeSignatureClass::Integer,
67+
TypeSignatureClass::Native(logical_null()),
68+
],
69+
NativeType::Int64,
70+
),
71+
Coercion::new_implicit(
72+
TypeSignatureClass::Native(logical_int64()),
73+
vec![
74+
TypeSignatureClass::Integer,
75+
TypeSignatureClass::Native(logical_null()),
76+
],
77+
NativeType::Int64,
78+
),
79+
]),
80+
Volatility::Immutable,
81+
),
6182
}
6283
}
6384
}

0 commit comments

Comments
 (0)