Skip to content

Commit 47664df

Browse files
authored
Introduce Signature::String and return error if input of strpos is integer (#12751)
* fix sig Signed-off-by: jayzhan211 <[email protected]> * fix Signed-off-by: jayzhan211 <[email protected]> * fix error Signed-off-by: jayzhan211 <[email protected]> * fix all signature Signed-off-by: jayzhan211 <[email protected]> * fix all signature Signed-off-by: jayzhan211 <[email protected]> * change default type Signed-off-by: jayzhan211 <[email protected]> * clippy Signed-off-by: jayzhan211 <[email protected]> * fix docs Signed-off-by: jayzhan211 <[email protected]> * rm deadcode Signed-off-by: jayzhan211 <[email protected]> * cleanup Signed-off-by: jayzhan211 <[email protected]> * cleanup Signed-off-by: jayzhan211 <[email protected]> * rm test Signed-off-by: jayzhan211 <[email protected]> --------- Signed-off-by: jayzhan211 <[email protected]>
1 parent d8405ba commit 47664df

31 files changed

+184
-384
lines changed

datafusion/core/tests/expr_api/mod.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,14 @@ mod simplification;
3737
fn test_octet_length() {
3838
#[rustfmt::skip]
3939
evaluate_expr_test(
40-
octet_length(col("list")),
40+
octet_length(col("id")),
4141
vec![
4242
"+------+",
4343
"| expr |",
4444
"+------+",
45-
"| 5 |",
46-
"| 18 |",
47-
"| 6 |",
45+
"| 1 |",
46+
"| 1 |",
47+
"| 1 |",
4848
"+------+",
4949
],
5050
);

datafusion/expr-common/src/signature.rs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,11 @@ pub enum TypeSignature {
125125
/// Fixed number of arguments of numeric types.
126126
/// See <https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html#method.is_numeric> to know which type is considered numeric
127127
Numeric(usize),
128+
/// Fixed number of arguments of all the same string types.
129+
/// The precedence of type from high to low is Utf8View, LargeUtf8 and Utf8.
130+
/// Null is considered as Utf8 by default.
131+
/// Dictionary with string value type is also handled.
132+
String(usize),
128133
}
129134

130135
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
@@ -190,8 +195,11 @@ impl TypeSignature {
190195
.collect::<Vec<String>>()
191196
.join(", ")]
192197
}
198+
TypeSignature::String(num) => {
199+
vec![format!("String({num})")]
200+
}
193201
TypeSignature::Numeric(num) => {
194-
vec![format!("Numeric({})", num)]
202+
vec![format!("Numeric({num})")]
195203
}
196204
TypeSignature::Exact(types) | TypeSignature::Coercible(types) => {
197205
vec![Self::join_types(types, ", ")]
@@ -280,6 +288,14 @@ impl Signature {
280288
}
281289
}
282290

291+
/// A specified number of string arguments
292+
pub fn string(arg_count: usize, volatility: Volatility) -> Self {
293+
Self {
294+
type_signature: TypeSignature::String(arg_count),
295+
volatility,
296+
}
297+
}
298+
283299
/// An arbitrary number of arguments of any type.
284300
pub fn variadic_any(volatility: Volatility) -> Self {
285301
Self {

datafusion/expr-common/src/type_coercion/binary.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -959,7 +959,7 @@ fn string_concat_internal_coercion(
959959
/// based on the observation that StringArray to StringViewArray is cheap but not vice versa.
960960
///
961961
/// Between Utf8 and LargeUtf8, we coerce to LargeUtf8.
962-
fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
962+
pub fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
963963
use arrow::datatypes::DataType::*;
964964
match (lhs_type, rhs_type) {
965965
// If Utf8View is in any side, we coerce to Utf8View.

datafusion/expr/src/type_coercion/functions.rs

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,9 @@ use datafusion_common::{
2626
utils::{coerced_fixed_size_list_to_list, list_ndims},
2727
Result,
2828
};
29-
use datafusion_expr_common::signature::{
30-
ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD,
29+
use datafusion_expr_common::{
30+
signature::{ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD},
31+
type_coercion::binary::string_coercion,
3132
};
3233
use std::sync::Arc;
3334

@@ -176,6 +177,7 @@ fn is_well_supported_signature(type_signature: &TypeSignature) -> bool {
176177
type_signature,
177178
TypeSignature::UserDefined
178179
| TypeSignature::Numeric(_)
180+
| TypeSignature::String(_)
179181
| TypeSignature::Coercible(_)
180182
| TypeSignature::Any(_)
181183
)
@@ -381,6 +383,67 @@ fn get_valid_types(
381383
.iter()
382384
.map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect())
383385
.collect(),
386+
TypeSignature::String(number) => {
387+
if *number < 1 {
388+
return plan_err!(
389+
"The signature expected at least one argument but received {}",
390+
current_types.len()
391+
);
392+
}
393+
if *number != current_types.len() {
394+
return plan_err!(
395+
"The signature expected {} arguments but received {}",
396+
number,
397+
current_types.len()
398+
);
399+
}
400+
401+
fn coercion_rule(
402+
lhs_type: &DataType,
403+
rhs_type: &DataType,
404+
) -> Result<DataType> {
405+
match (lhs_type, rhs_type) {
406+
(DataType::Null, DataType::Null) => Ok(DataType::Utf8),
407+
(DataType::Null, data_type) | (data_type, DataType::Null) => {
408+
coercion_rule(data_type, &DataType::Utf8)
409+
}
410+
(DataType::Dictionary(_, lhs), DataType::Dictionary(_, rhs)) => {
411+
coercion_rule(lhs, rhs)
412+
}
413+
(DataType::Dictionary(_, v), other)
414+
| (other, DataType::Dictionary(_, v)) => coercion_rule(v, other),
415+
_ => {
416+
if let Some(coerced_type) = string_coercion(lhs_type, rhs_type) {
417+
Ok(coerced_type)
418+
} else {
419+
plan_err!(
420+
"{} and {} are not coercible to a common string type",
421+
lhs_type,
422+
rhs_type
423+
)
424+
}
425+
}
426+
}
427+
}
428+
429+
// Length checked above, safe to unwrap
430+
let mut coerced_type = current_types.first().unwrap().to_owned();
431+
for t in current_types.iter().skip(1) {
432+
coerced_type = coercion_rule(&coerced_type, t)?;
433+
}
434+
435+
fn base_type_or_default_type(data_type: &DataType) -> DataType {
436+
if data_type.is_null() {
437+
DataType::Utf8
438+
} else if let DataType::Dictionary(_, v) = data_type {
439+
base_type_or_default_type(v)
440+
} else {
441+
data_type.to_owned()
442+
}
443+
}
444+
445+
vec![vec![base_type_or_default_type(&coerced_type); *number]]
446+
}
384447
TypeSignature::Numeric(number) => {
385448
if *number < 1 {
386449
return plan_err!(

datafusion/functions/src/macros.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ macro_rules! make_math_binary_udf {
284284
use arrow::datatypes::DataType;
285285
use datafusion_common::{exec_err, DataFusionError, Result};
286286
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
287-
use datafusion_expr::TypeSignature::*;
287+
use datafusion_expr::TypeSignature;
288288
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
289289

290290
#[derive(Debug)]
@@ -298,8 +298,8 @@ macro_rules! make_math_binary_udf {
298298
Self {
299299
signature: Signature::one_of(
300300
vec![
301-
Exact(vec![Float32, Float32]),
302-
Exact(vec![Float64, Float64]),
301+
TypeSignature::Exact(vec![Float32, Float32]),
302+
TypeSignature::Exact(vec![Float64, Float64]),
303303
],
304304
Volatility::Immutable,
305305
),

datafusion/functions/src/math/nans.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,9 @@
1919
2020
use arrow::datatypes::DataType;
2121
use datafusion_common::{exec_err, DataFusionError, Result};
22-
use datafusion_expr::ColumnarValue;
22+
use datafusion_expr::{ColumnarValue, TypeSignature};
2323

2424
use arrow::array::{ArrayRef, BooleanArray, Float32Array, Float64Array};
25-
use datafusion_expr::TypeSignature::*;
2625
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
2726
use std::any::Any;
2827
use std::sync::Arc;
@@ -43,7 +42,10 @@ impl IsNanFunc {
4342
use DataType::*;
4443
Self {
4544
signature: Signature::one_of(
46-
vec![Exact(vec![Float32]), Exact(vec![Float64])],
45+
vec![
46+
TypeSignature::Exact(vec![Float32]),
47+
TypeSignature::Exact(vec![Float64]),
48+
],
4749
Volatility::Immutable,
4850
),
4951
}

datafusion/functions/src/math/power.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,9 @@ use datafusion_common::{
2525
};
2626
use datafusion_expr::expr::ScalarFunction;
2727
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
28-
use datafusion_expr::{ColumnarValue, Expr, ScalarUDF};
28+
use datafusion_expr::{ColumnarValue, Expr, ScalarUDF, TypeSignature};
2929

3030
use arrow::array::{ArrayRef, Float64Array, Int64Array};
31-
use datafusion_expr::TypeSignature::*;
3231
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
3332
use std::any::Any;
3433
use std::sync::Arc;
@@ -52,7 +51,10 @@ impl PowerFunc {
5251
use DataType::*;
5352
Self {
5453
signature: Signature::one_of(
55-
vec![Exact(vec![Int64, Int64]), Exact(vec![Float64, Float64])],
54+
vec![
55+
TypeSignature::Exact(vec![Int64, Int64]),
56+
TypeSignature::Exact(vec![Float64, Float64]),
57+
],
5658
Volatility::Immutable,
5759
),
5860
aliases: vec![String::from("pow")],

datafusion/functions/src/regex/regexplike.rs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@ use datafusion_common::{
2626
cast::as_generic_string_array, internal_err, DataFusionError, Result,
2727
};
2828
use datafusion_expr::scalar_doc_sections::DOC_SECTION_REGEX;
29-
use datafusion_expr::TypeSignature::*;
30-
use datafusion_expr::{ColumnarValue, Documentation};
29+
use datafusion_expr::{ColumnarValue, Documentation, TypeSignature};
3130
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
3231
use std::any::Any;
3332
use std::sync::{Arc, OnceLock};
@@ -87,10 +86,10 @@ impl RegexpLikeFunc {
8786
Self {
8887
signature: Signature::one_of(
8988
vec![
90-
Exact(vec![Utf8, Utf8]),
91-
Exact(vec![LargeUtf8, LargeUtf8]),
92-
Exact(vec![Utf8, Utf8, Utf8]),
93-
Exact(vec![LargeUtf8, LargeUtf8, LargeUtf8]),
89+
TypeSignature::Exact(vec![Utf8, Utf8]),
90+
TypeSignature::Exact(vec![LargeUtf8, LargeUtf8]),
91+
TypeSignature::Exact(vec![Utf8, Utf8, Utf8]),
92+
TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, LargeUtf8]),
9493
],
9594
Volatility::Immutable,
9695
),

datafusion/functions/src/regex/regexpmatch.rs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@ use datafusion_common::{arrow_datafusion_err, plan_err};
2626
use datafusion_common::{
2727
cast::as_generic_string_array, internal_err, DataFusionError, Result,
2828
};
29-
use datafusion_expr::ColumnarValue;
30-
use datafusion_expr::TypeSignature::*;
29+
use datafusion_expr::{ColumnarValue, TypeSignature};
3130
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
3231
use std::any::Any;
3332
use std::sync::Arc;
@@ -53,10 +52,10 @@ impl RegexpMatchFunc {
5352
// For example, given input `(Utf8View, Utf8)`, it first tries coercing to `(Utf8, Utf8)`.
5453
// If that fails, it proceeds to `(LargeUtf8, Utf8)`.
5554
// TODO: Native support Utf8View for regexp_match.
56-
Exact(vec![Utf8, Utf8]),
57-
Exact(vec![LargeUtf8, LargeUtf8]),
58-
Exact(vec![Utf8, Utf8, Utf8]),
59-
Exact(vec![LargeUtf8, LargeUtf8, LargeUtf8]),
55+
TypeSignature::Exact(vec![Utf8, Utf8]),
56+
TypeSignature::Exact(vec![LargeUtf8, LargeUtf8]),
57+
TypeSignature::Exact(vec![Utf8, Utf8, Utf8]),
58+
TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, LargeUtf8]),
6059
],
6160
Volatility::Immutable,
6261
),

datafusion/functions/src/regex/regexpreplace.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ use datafusion_common::{
3333
};
3434
use datafusion_expr::function::Hint;
3535
use datafusion_expr::ColumnarValue;
36-
use datafusion_expr::TypeSignature::*;
36+
use datafusion_expr::TypeSignature;
3737
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
3838
use regex::Regex;
3939
use std::any::Any;
@@ -56,10 +56,10 @@ impl RegexpReplaceFunc {
5656
Self {
5757
signature: Signature::one_of(
5858
vec![
59-
Exact(vec![Utf8, Utf8, Utf8]),
60-
Exact(vec![Utf8View, Utf8, Utf8]),
61-
Exact(vec![Utf8, Utf8, Utf8, Utf8]),
62-
Exact(vec![Utf8View, Utf8, Utf8, Utf8]),
59+
TypeSignature::Exact(vec![Utf8, Utf8, Utf8]),
60+
TypeSignature::Exact(vec![Utf8View, Utf8, Utf8]),
61+
TypeSignature::Exact(vec![Utf8, Utf8, Utf8, Utf8]),
62+
TypeSignature::Exact(vec![Utf8View, Utf8, Utf8, Utf8]),
6363
],
6464
Volatility::Immutable,
6565
),

datafusion/functions/src/string/ascii.rs

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,8 @@ impl Default for AsciiFunc {
3939

4040
impl AsciiFunc {
4141
pub fn new() -> Self {
42-
use DataType::*;
4342
Self {
44-
signature: Signature::uniform(
45-
1,
46-
vec![Utf8, LargeUtf8, Utf8View],
47-
Volatility::Immutable,
48-
),
43+
signature: Signature::string(1, Volatility::Immutable),
4944
}
5045
}
5146
}

datafusion/functions/src/string/bit_length.rs

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,8 @@ impl Default for BitLengthFunc {
3939

4040
impl BitLengthFunc {
4141
pub fn new() -> Self {
42-
use DataType::*;
4342
Self {
44-
signature: Signature::uniform(
45-
1,
46-
vec![Utf8, LargeUtf8],
47-
Volatility::Immutable,
48-
),
43+
signature: Signature::string(1, Volatility::Immutable),
4944
}
5045
}
5146
}

datafusion/functions/src/string/btrim.rs

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@ use arrow::datatypes::DataType;
2222
use datafusion_common::{exec_err, Result};
2323
use datafusion_expr::function::Hint;
2424
use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING;
25-
use datafusion_expr::TypeSignature::*;
26-
use datafusion_expr::{ColumnarValue, Documentation, Volatility};
27-
use datafusion_expr::{ScalarUDFImpl, Signature};
25+
use datafusion_expr::{
26+
ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature, Volatility,
27+
};
2828
use std::any::Any;
2929
use std::sync::OnceLock;
3030

@@ -49,18 +49,9 @@ impl Default for BTrimFunc {
4949

5050
impl BTrimFunc {
5151
pub fn new() -> Self {
52-
use DataType::*;
5352
Self {
5453
signature: Signature::one_of(
55-
vec![
56-
// Planner attempts coercion to the target type starting with the most preferred candidate.
57-
// For example, given input `(Utf8View, Utf8)`, it first tries coercing to `(Utf8View, Utf8View)`.
58-
// If that fails, it proceeds to `(Utf8, Utf8)`.
59-
Exact(vec![Utf8View, Utf8View]),
60-
Exact(vec![Utf8, Utf8]),
61-
Exact(vec![Utf8View]),
62-
Exact(vec![Utf8]),
63-
],
54+
vec![TypeSignature::String(2), TypeSignature::String(1)],
6455
Volatility::Immutable,
6556
),
6657
aliases: vec![String::from("trim")],

0 commit comments

Comments
 (0)