Skip to content

Commit d1e6fed

Browse files
delamarch3goldmedal
authored andcommitted
Implement get_possible_types for Uniform, Coercible, Variadic, Numeric and String (apache#13313)
* implement get_possible_types for Uniform, Coercible, Variadic, Numeric and String * fix possible types for variadic * use cloned * add Utf8View to STRINGS * add todo to support other native types Co-authored-by: Jax Liu <[email protected]> --------- Co-authored-by: Jax Liu <[email protected]>
1 parent 130e149 commit d1e6fed

File tree

2 files changed

+123
-9
lines changed

2 files changed

+123
-9
lines changed

datafusion/expr-common/src/signature.rs

Lines changed: 121 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@
1818
//! Signature module contains foundational types that are used to represent signatures, types,
1919
//! and return types of functions in DataFusion.
2020
21+
use crate::type_coercion::aggregates::{NUMERICS, STRINGS};
2122
use arrow::datatypes::DataType;
22-
use datafusion_common::types::LogicalTypeRef;
23+
use datafusion_common::types::{LogicalTypeRef, NativeType};
24+
use itertools::Itertools;
2325

2426
/// Constant that is used as a placeholder for any valid timezone.
2527
/// This is used where a function can accept a timestamp type with any
@@ -258,17 +260,66 @@ impl TypeSignature {
258260
.iter()
259261
.flat_map(|type_sig| type_sig.get_possible_types())
260262
.collect(),
263+
TypeSignature::Uniform(arg_count, types) => types
264+
.iter()
265+
.cloned()
266+
.map(|data_type| vec![data_type; *arg_count])
267+
.collect(),
268+
TypeSignature::Coercible(types) => types
269+
.iter()
270+
.map(|logical_type| get_data_types(logical_type.native()))
271+
.multi_cartesian_product()
272+
.collect(),
273+
TypeSignature::Variadic(types) => types
274+
.iter()
275+
.cloned()
276+
.map(|data_type| vec![data_type])
277+
.collect(),
278+
TypeSignature::Numeric(arg_count) => NUMERICS
279+
.iter()
280+
.cloned()
281+
.map(|numeric_type| vec![numeric_type; *arg_count])
282+
.collect(),
283+
TypeSignature::String(arg_count) => STRINGS
284+
.iter()
285+
.cloned()
286+
.map(|string_type| vec![string_type; *arg_count])
287+
.collect(),
261288
// TODO: Implement for other types
262-
TypeSignature::Uniform(_, _)
263-
| TypeSignature::Coercible(_)
264-
| TypeSignature::Any(_)
265-
| TypeSignature::Variadic(_)
289+
TypeSignature::Any(_)
266290
| TypeSignature::VariadicAny
267-
| TypeSignature::UserDefined
268291
| TypeSignature::ArraySignature(_)
269-
| TypeSignature::Numeric(_)
270-
| TypeSignature::String(_) => vec![],
292+
| TypeSignature::UserDefined => vec![],
293+
}
294+
}
295+
}
296+
297+
fn get_data_types(native_type: &NativeType) -> Vec<DataType> {
298+
match native_type {
299+
NativeType::Null => vec![DataType::Null],
300+
NativeType::Boolean => vec![DataType::Boolean],
301+
NativeType::Int8 => vec![DataType::Int8],
302+
NativeType::Int16 => vec![DataType::Int16],
303+
NativeType::Int32 => vec![DataType::Int32],
304+
NativeType::Int64 => vec![DataType::Int64],
305+
NativeType::UInt8 => vec![DataType::UInt8],
306+
NativeType::UInt16 => vec![DataType::UInt16],
307+
NativeType::UInt32 => vec![DataType::UInt32],
308+
NativeType::UInt64 => vec![DataType::UInt64],
309+
NativeType::Float16 => vec![DataType::Float16],
310+
NativeType::Float32 => vec![DataType::Float32],
311+
NativeType::Float64 => vec![DataType::Float64],
312+
NativeType::Date => vec![DataType::Date32, DataType::Date64],
313+
NativeType::Binary => vec![
314+
DataType::Binary,
315+
DataType::LargeBinary,
316+
DataType::BinaryView,
317+
],
318+
NativeType::String => {
319+
vec![DataType::Utf8, DataType::LargeUtf8, DataType::Utf8View]
271320
}
321+
// TODO: support other native types
322+
_ => vec![],
272323
}
273324
}
274325

@@ -417,6 +468,8 @@ impl Signature {
417468

418469
#[cfg(test)]
419470
mod tests {
471+
use datafusion_common::types::{logical_int64, logical_string};
472+
420473
use super::*;
421474

422475
#[test]
@@ -515,5 +568,65 @@ mod tests {
515568
vec![DataType::Utf8]
516569
]
517570
);
571+
572+
let type_signature =
573+
TypeSignature::Uniform(2, vec![DataType::Float32, DataType::Int64]);
574+
let possible_types = type_signature.get_possible_types();
575+
assert_eq!(
576+
possible_types,
577+
vec![
578+
vec![DataType::Float32, DataType::Float32],
579+
vec![DataType::Int64, DataType::Int64]
580+
]
581+
);
582+
583+
let type_signature =
584+
TypeSignature::Coercible(vec![logical_string(), logical_int64()]);
585+
let possible_types = type_signature.get_possible_types();
586+
assert_eq!(
587+
possible_types,
588+
vec![
589+
vec![DataType::Utf8, DataType::Int64],
590+
vec![DataType::LargeUtf8, DataType::Int64],
591+
vec![DataType::Utf8View, DataType::Int64]
592+
]
593+
);
594+
595+
let type_signature =
596+
TypeSignature::Variadic(vec![DataType::Int32, DataType::Int64]);
597+
let possible_types = type_signature.get_possible_types();
598+
assert_eq!(
599+
possible_types,
600+
vec![vec![DataType::Int32], vec![DataType::Int64]]
601+
);
602+
603+
let type_signature = TypeSignature::Numeric(2);
604+
let possible_types = type_signature.get_possible_types();
605+
assert_eq!(
606+
possible_types,
607+
vec![
608+
vec![DataType::Int8, DataType::Int8],
609+
vec![DataType::Int16, DataType::Int16],
610+
vec![DataType::Int32, DataType::Int32],
611+
vec![DataType::Int64, DataType::Int64],
612+
vec![DataType::UInt8, DataType::UInt8],
613+
vec![DataType::UInt16, DataType::UInt16],
614+
vec![DataType::UInt32, DataType::UInt32],
615+
vec![DataType::UInt64, DataType::UInt64],
616+
vec![DataType::Float32, DataType::Float32],
617+
vec![DataType::Float64, DataType::Float64]
618+
]
619+
);
620+
621+
let type_signature = TypeSignature::String(2);
622+
let possible_types = type_signature.get_possible_types();
623+
assert_eq!(
624+
possible_types,
625+
vec![
626+
vec![DataType::Utf8, DataType::Utf8],
627+
vec![DataType::LargeUtf8, DataType::LargeUtf8],
628+
vec![DataType::Utf8View, DataType::Utf8View]
629+
]
630+
);
518631
}
519632
}

datafusion/expr-common/src/type_coercion/aggregates.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ use arrow::datatypes::{
2323

2424
use datafusion_common::{internal_err, plan_err, Result};
2525

26-
pub static STRINGS: &[DataType] = &[DataType::Utf8, DataType::LargeUtf8];
26+
pub static STRINGS: &[DataType] =
27+
&[DataType::Utf8, DataType::LargeUtf8, DataType::Utf8View];
2728

2829
pub static SIGNED_INTEGERS: &[DataType] = &[
2930
DataType::Int8,

0 commit comments

Comments
 (0)