|
18 | 18 | //! Signature module contains foundational types that are used to represent signatures, types,
|
19 | 19 | //! and return types of functions in DataFusion.
|
20 | 20 |
|
| 21 | +use crate::type_coercion::aggregates::{NUMERICS, STRINGS}; |
21 | 22 | use arrow::datatypes::DataType;
|
22 |
| -use datafusion_common::types::LogicalTypeRef; |
| 23 | +use datafusion_common::types::{LogicalTypeRef, NativeType}; |
| 24 | +use itertools::Itertools; |
23 | 25 |
|
24 | 26 | /// Constant that is used as a placeholder for any valid timezone.
|
25 | 27 | /// This is used where a function can accept a timestamp type with any
|
@@ -258,17 +260,66 @@ impl TypeSignature {
|
258 | 260 | .iter()
|
259 | 261 | .flat_map(|type_sig| type_sig.get_possible_types())
|
260 | 262 | .collect(),
|
| 263 | + TypeSignature::Uniform(arg_count, types) => types |
| 264 | + .iter() |
| 265 | + .cloned() |
| 266 | + .map(|data_type| vec![data_type; *arg_count]) |
| 267 | + .collect(), |
| 268 | + TypeSignature::Coercible(types) => types |
| 269 | + .iter() |
| 270 | + .map(|logical_type| get_data_types(logical_type.native())) |
| 271 | + .multi_cartesian_product() |
| 272 | + .collect(), |
| 273 | + TypeSignature::Variadic(types) => types |
| 274 | + .iter() |
| 275 | + .cloned() |
| 276 | + .map(|data_type| vec![data_type]) |
| 277 | + .collect(), |
| 278 | + TypeSignature::Numeric(arg_count) => NUMERICS |
| 279 | + .iter() |
| 280 | + .cloned() |
| 281 | + .map(|numeric_type| vec![numeric_type; *arg_count]) |
| 282 | + .collect(), |
| 283 | + TypeSignature::String(arg_count) => STRINGS |
| 284 | + .iter() |
| 285 | + .cloned() |
| 286 | + .map(|string_type| vec![string_type; *arg_count]) |
| 287 | + .collect(), |
261 | 288 | // TODO: Implement for other types
|
262 |
| - TypeSignature::Uniform(_, _) |
263 |
| - | TypeSignature::Coercible(_) |
264 |
| - | TypeSignature::Any(_) |
265 |
| - | TypeSignature::Variadic(_) |
| 289 | + TypeSignature::Any(_) |
266 | 290 | | TypeSignature::VariadicAny
|
267 |
| - | TypeSignature::UserDefined |
268 | 291 | | TypeSignature::ArraySignature(_)
|
269 |
| - | TypeSignature::Numeric(_) |
270 |
| - | TypeSignature::String(_) => vec![], |
| 292 | + | TypeSignature::UserDefined => vec![], |
| 293 | + } |
| 294 | + } |
| 295 | +} |
| 296 | + |
| 297 | +fn get_data_types(native_type: &NativeType) -> Vec<DataType> { |
| 298 | + match native_type { |
| 299 | + NativeType::Null => vec![DataType::Null], |
| 300 | + NativeType::Boolean => vec![DataType::Boolean], |
| 301 | + NativeType::Int8 => vec![DataType::Int8], |
| 302 | + NativeType::Int16 => vec![DataType::Int16], |
| 303 | + NativeType::Int32 => vec![DataType::Int32], |
| 304 | + NativeType::Int64 => vec![DataType::Int64], |
| 305 | + NativeType::UInt8 => vec![DataType::UInt8], |
| 306 | + NativeType::UInt16 => vec![DataType::UInt16], |
| 307 | + NativeType::UInt32 => vec![DataType::UInt32], |
| 308 | + NativeType::UInt64 => vec![DataType::UInt64], |
| 309 | + NativeType::Float16 => vec![DataType::Float16], |
| 310 | + NativeType::Float32 => vec![DataType::Float32], |
| 311 | + NativeType::Float64 => vec![DataType::Float64], |
| 312 | + NativeType::Date => vec![DataType::Date32, DataType::Date64], |
| 313 | + NativeType::Binary => vec![ |
| 314 | + DataType::Binary, |
| 315 | + DataType::LargeBinary, |
| 316 | + DataType::BinaryView, |
| 317 | + ], |
| 318 | + NativeType::String => { |
| 319 | + vec![DataType::Utf8, DataType::LargeUtf8, DataType::Utf8View] |
271 | 320 | }
|
| 321 | + // TODO: support other native types |
| 322 | + _ => vec![], |
272 | 323 | }
|
273 | 324 | }
|
274 | 325 |
|
@@ -417,6 +468,8 @@ impl Signature {
|
417 | 468 |
|
418 | 469 | #[cfg(test)]
|
419 | 470 | mod tests {
|
| 471 | + use datafusion_common::types::{logical_int64, logical_string}; |
| 472 | + |
420 | 473 | use super::*;
|
421 | 474 |
|
422 | 475 | #[test]
|
@@ -515,5 +568,65 @@ mod tests {
|
515 | 568 | vec![DataType::Utf8]
|
516 | 569 | ]
|
517 | 570 | );
|
| 571 | + |
| 572 | + let type_signature = |
| 573 | + TypeSignature::Uniform(2, vec![DataType::Float32, DataType::Int64]); |
| 574 | + let possible_types = type_signature.get_possible_types(); |
| 575 | + assert_eq!( |
| 576 | + possible_types, |
| 577 | + vec![ |
| 578 | + vec![DataType::Float32, DataType::Float32], |
| 579 | + vec![DataType::Int64, DataType::Int64] |
| 580 | + ] |
| 581 | + ); |
| 582 | + |
| 583 | + let type_signature = |
| 584 | + TypeSignature::Coercible(vec![logical_string(), logical_int64()]); |
| 585 | + let possible_types = type_signature.get_possible_types(); |
| 586 | + assert_eq!( |
| 587 | + possible_types, |
| 588 | + vec![ |
| 589 | + vec![DataType::Utf8, DataType::Int64], |
| 590 | + vec![DataType::LargeUtf8, DataType::Int64], |
| 591 | + vec![DataType::Utf8View, DataType::Int64] |
| 592 | + ] |
| 593 | + ); |
| 594 | + |
| 595 | + let type_signature = |
| 596 | + TypeSignature::Variadic(vec![DataType::Int32, DataType::Int64]); |
| 597 | + let possible_types = type_signature.get_possible_types(); |
| 598 | + assert_eq!( |
| 599 | + possible_types, |
| 600 | + vec![vec![DataType::Int32], vec![DataType::Int64]] |
| 601 | + ); |
| 602 | + |
| 603 | + let type_signature = TypeSignature::Numeric(2); |
| 604 | + let possible_types = type_signature.get_possible_types(); |
| 605 | + assert_eq!( |
| 606 | + possible_types, |
| 607 | + vec![ |
| 608 | + vec![DataType::Int8, DataType::Int8], |
| 609 | + vec![DataType::Int16, DataType::Int16], |
| 610 | + vec![DataType::Int32, DataType::Int32], |
| 611 | + vec![DataType::Int64, DataType::Int64], |
| 612 | + vec![DataType::UInt8, DataType::UInt8], |
| 613 | + vec![DataType::UInt16, DataType::UInt16], |
| 614 | + vec![DataType::UInt32, DataType::UInt32], |
| 615 | + vec![DataType::UInt64, DataType::UInt64], |
| 616 | + vec![DataType::Float32, DataType::Float32], |
| 617 | + vec![DataType::Float64, DataType::Float64] |
| 618 | + ] |
| 619 | + ); |
| 620 | + |
| 621 | + let type_signature = TypeSignature::String(2); |
| 622 | + let possible_types = type_signature.get_possible_types(); |
| 623 | + assert_eq!( |
| 624 | + possible_types, |
| 625 | + vec![ |
| 626 | + vec![DataType::Utf8, DataType::Utf8], |
| 627 | + vec![DataType::LargeUtf8, DataType::LargeUtf8], |
| 628 | + vec![DataType::Utf8View, DataType::Utf8View] |
| 629 | + ] |
| 630 | + ); |
518 | 631 | }
|
519 | 632 | }
|
0 commit comments