Skip to content

Commit f2cdc14

Browse files
authored
replace TypeSignature::String with TypeSignature::Coercible for trim functions (#14865)
* replace type signature for trim functions
* make clippy happy
1 parent 9278233 commit f2cdc14

File tree

5 files changed

+71
-19
lines changed

5 files changed

+71
-19
lines changed

datafusion/expr/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ pub use datafusion_expr_common::columnar_value::ColumnarValue;
7878
pub use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator};
7979
pub use datafusion_expr_common::operator::Operator;
8080
pub use datafusion_expr_common::signature::{
81-
ArrayFunctionArgument, ArrayFunctionSignature, Signature, TypeSignature,
81+
ArrayFunctionArgument, ArrayFunctionSignature, Coercion, Signature, TypeSignature,
8282
TypeSignatureClass, Volatility, TIMEZONE_WILDCARD,
8383
};
8484
pub use datafusion_expr_common::type_coercion::binary;

datafusion/functions/src/string/btrim.rs

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,28 @@ use crate::string::common::*;
1919
use crate::utils::{make_scalar_function, utf8_to_str_type};
2020
use arrow::array::{ArrayRef, OffsetSizeTrait};
2121
use arrow::datatypes::DataType;
22+
use datafusion_common::types::logical_string;
2223
use datafusion_common::{exec_err, Result};
2324
use datafusion_expr::function::Hint;
2425
use datafusion_expr::{
25-
ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
26-
TypeSignature, Volatility,
26+
Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
27+
TypeSignature, TypeSignatureClass, Volatility,
2728
};
2829
use datafusion_macros::user_doc;
2930
use std::any::Any;
31+
use std::sync::Arc;
3032

3133
/// Returns the longest string with leading and trailing characters removed. If the characters are not specified, whitespace is removed.
3234
/// btrim('xyxtrimyyx', 'xyz') = 'trim'
3335
fn btrim<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
3436
let use_string_view = args[0].data_type() == &DataType::Utf8View;
35-
general_trim::<T>(args, TrimType::Both, use_string_view)
37+
let args = if args.len() > 1 {
38+
let arg1 = arrow::compute::kernels::cast::cast(&args[1], args[0].data_type())?;
39+
vec![Arc::clone(&args[0]), arg1]
40+
} else {
41+
args.to_owned()
42+
};
43+
general_trim::<T>(&args, TrimType::Both, use_string_view)
3644
}
3745

3846
#[user_doc(
@@ -73,7 +81,15 @@ impl BTrimFunc {
7381
pub fn new() -> Self {
7482
Self {
7583
signature: Signature::one_of(
76-
vec![TypeSignature::String(2), TypeSignature::String(1)],
84+
vec![
85+
TypeSignature::Coercible(vec![
86+
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
87+
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
88+
]),
89+
TypeSignature::Coercible(vec![Coercion::new_exact(
90+
TypeSignatureClass::Native(logical_string()),
91+
)]),
92+
],
7793
Volatility::Immutable,
7894
),
7995
aliases: vec![String::from("trim")],

datafusion/functions/src/string/ltrim.rs

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,30 @@
1818
use arrow::array::{ArrayRef, OffsetSizeTrait};
1919
use arrow::datatypes::DataType;
2020
use std::any::Any;
21+
use std::sync::Arc;
2122

2223
use crate::string::common::*;
2324
use crate::utils::{make_scalar_function, utf8_to_str_type};
25+
use datafusion_common::types::logical_string;
2426
use datafusion_common::{exec_err, Result};
2527
use datafusion_expr::function::Hint;
26-
use datafusion_expr::{ColumnarValue, Documentation, TypeSignature, Volatility};
27-
use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
28+
use datafusion_expr::{
29+
Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
30+
TypeSignature, TypeSignatureClass, Volatility,
31+
};
2832
use datafusion_macros::user_doc;
2933

3034
/// Returns the longest string with leading characters removed. If the characters are not specified, whitespace is removed.
3135
/// ltrim('zzzytest', 'xyz') = 'test'
3236
fn ltrim<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
3337
let use_string_view = args[0].data_type() == &DataType::Utf8View;
34-
general_trim::<T>(args, TrimType::Left, use_string_view)
38+
let args = if args.len() > 1 {
39+
let arg1 = arrow::compute::kernels::cast::cast(&args[1], args[0].data_type())?;
40+
vec![Arc::clone(&args[0]), arg1]
41+
} else {
42+
args.to_owned()
43+
};
44+
general_trim::<T>(&args, TrimType::Left, use_string_view)
3545
}
3646

3747
#[user_doc(
@@ -76,7 +86,15 @@ impl LtrimFunc {
7686
pub fn new() -> Self {
7787
Self {
7888
signature: Signature::one_of(
79-
vec![TypeSignature::String(2), TypeSignature::String(1)],
89+
vec![
90+
TypeSignature::Coercible(vec![
91+
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
92+
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
93+
]),
94+
TypeSignature::Coercible(vec![Coercion::new_exact(
95+
TypeSignatureClass::Native(logical_string()),
96+
)]),
97+
],
8098
Volatility::Immutable,
8199
),
82100
}

datafusion/functions/src/string/rtrim.rs

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,30 @@
1818
use arrow::array::{ArrayRef, OffsetSizeTrait};
1919
use arrow::datatypes::DataType;
2020
use std::any::Any;
21+
use std::sync::Arc;
2122

2223
use crate::string::common::*;
2324
use crate::utils::{make_scalar_function, utf8_to_str_type};
25+
use datafusion_common::types::logical_string;
2426
use datafusion_common::{exec_err, Result};
2527
use datafusion_expr::function::Hint;
26-
use datafusion_expr::{ColumnarValue, Documentation, TypeSignature, Volatility};
27-
use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
28+
use datafusion_expr::{
29+
Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
30+
TypeSignature, TypeSignatureClass, Volatility,
31+
};
2832
use datafusion_macros::user_doc;
2933

3034
/// Returns the longest string with trailing characters removed. If the characters are not specified, whitespace is removed.
3135
/// rtrim('testxxzx', 'xyz') = 'test'
3236
fn rtrim<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
3337
let use_string_view = args[0].data_type() == &DataType::Utf8View;
34-
general_trim::<T>(args, TrimType::Right, use_string_view)
38+
let args = if args.len() > 1 {
39+
let arg1 = arrow::compute::kernels::cast::cast(&args[1], args[0].data_type())?;
40+
vec![Arc::clone(&args[0]), arg1]
41+
} else {
42+
args.to_owned()
43+
};
44+
general_trim::<T>(&args, TrimType::Right, use_string_view)
3545
}
3646

3747
#[user_doc(
@@ -76,7 +86,15 @@ impl RtrimFunc {
7686
pub fn new() -> Self {
7787
Self {
7888
signature: Signature::one_of(
79-
vec![TypeSignature::String(2), TypeSignature::String(1)],
89+
vec![
90+
TypeSignature::Coercible(vec![
91+
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
92+
Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
93+
]),
94+
TypeSignature::Coercible(vec![Coercion::new_exact(
95+
TypeSignatureClass::Native(logical_string()),
96+
)]),
97+
],
8098
Volatility::Immutable,
8199
),
82100
}

datafusion/sqllogictest/test_files/string/string_view.slt

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -530,7 +530,7 @@ EXPLAIN SELECT
530530
FROM test;
531531
----
532532
logical_plan
533-
01)Projection: btrim(test.column1_utf8view, Utf8View("foo")) AS l
533+
01)Projection: btrim(test.column1_utf8view, Utf8("foo")) AS l
534534
02)--TableScan: test projection=[column1_utf8view]
535535

536536
# Test BTRIM with Utf8View bytes longer than 12
@@ -540,7 +540,7 @@ EXPLAIN SELECT
540540
FROM test;
541541
----
542542
logical_plan
543-
01)Projection: btrim(test.column1_utf8view, Utf8View("this is longer than 12")) AS l
543+
01)Projection: btrim(test.column1_utf8view, Utf8("this is longer than 12")) AS l
544544
02)--TableScan: test projection=[column1_utf8view]
545545

546546
## Ensure no casts for LTRIM
@@ -561,7 +561,7 @@ EXPLAIN SELECT
561561
FROM test;
562562
----
563563
logical_plan
564-
01)Projection: ltrim(test.column1_utf8view, Utf8View("foo")) AS l
564+
01)Projection: ltrim(test.column1_utf8view, Utf8("foo")) AS l
565565
02)--TableScan: test projection=[column1_utf8view]
566566

567567
# Test LTRIM with Utf8View bytes longer than 12
@@ -571,7 +571,7 @@ EXPLAIN SELECT
571571
FROM test;
572572
----
573573
logical_plan
574-
01)Projection: ltrim(test.column1_utf8view, Utf8View("this is longer than 12")) AS l
574+
01)Projection: ltrim(test.column1_utf8view, Utf8("this is longer than 12")) AS l
575575
02)--TableScan: test projection=[column1_utf8view]
576576

577577
## ensure no casts for RTRIM
@@ -592,7 +592,7 @@ EXPLAIN SELECT
592592
FROM test;
593593
----
594594
logical_plan
595-
01)Projection: rtrim(test.column1_utf8view, Utf8View("foo")) AS l
595+
01)Projection: rtrim(test.column1_utf8view, Utf8("foo")) AS l
596596
02)--TableScan: test projection=[column1_utf8view]
597597

598598
# Test RTRIM with Utf8View bytes longer than 12
@@ -602,7 +602,7 @@ EXPLAIN SELECT
602602
FROM test;
603603
----
604604
logical_plan
605-
01)Projection: rtrim(test.column1_utf8view, Utf8View("this is longer than 12")) AS l
605+
01)Projection: rtrim(test.column1_utf8view, Utf8("this is longer than 12")) AS l
606606
02)--TableScan: test projection=[column1_utf8view]
607607

608608
## Ensure no casts for CHARACTER_LENGTH

0 commit comments

Comments
 (0)