Skip to content

Commit 822a368

Browse files
zjregeeyurunjie
authored and
yurunjie
committed
migrate string functions to invoke_with_args
1 parent 19fe44c commit 822a368

21 files changed

+150
-186
lines changed

datafusion/functions/src/string/ascii.rs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use arrow::error::ArrowError;
2222
use datafusion_common::types::logical_string;
2323
use datafusion_common::{internal_err, Result};
2424
use datafusion_expr::{ColumnarValue, Documentation, TypeSignatureClass};
25-
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
25+
use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
2626
use datafusion_expr_common::signature::Coercion;
2727
use datafusion_macros::user_doc;
2828
use std::any::Any;
@@ -92,12 +92,8 @@ impl ScalarUDFImpl for AsciiFunc {
9292
Ok(Int32)
9393
}
9494

95-
fn invoke_batch(
96-
&self,
97-
args: &[ColumnarValue],
98-
_number_rows: usize,
99-
) -> Result<ColumnarValue> {
100-
make_scalar_function(ascii, vec![])(args)
95+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
96+
make_scalar_function(ascii, vec![])(&args.args)
10197
}
10298

10399
fn documentation(&self) -> Option<&Documentation> {

datafusion/functions/src/string/bit_length.rs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use std::any::Any;
2222
use crate::utils::utf8_to_int_type;
2323
use datafusion_common::{utils::take_function_args, Result, ScalarValue};
2424
use datafusion_expr::{ColumnarValue, Documentation, Volatility};
25-
use datafusion_expr::{ScalarUDFImpl, Signature};
25+
use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
2626
use datafusion_macros::user_doc;
2727

2828
#[user_doc(
@@ -77,12 +77,8 @@ impl ScalarUDFImpl for BitLengthFunc {
7777
utf8_to_int_type(&arg_types[0], "bit_length")
7878
}
7979

80-
fn invoke_batch(
81-
&self,
82-
args: &[ColumnarValue],
83-
_number_rows: usize,
84-
) -> Result<ColumnarValue> {
85-
let [array] = take_function_args(self.name(), args)?;
80+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
81+
let [array] = take_function_args(self.name(), &args.args)?;
8682

8783
match array {
8884
ColumnarValue::Array(v) => Ok(ColumnarValue::Array(bit_length(v.as_ref())?)),

datafusion/functions/src/string/btrim.rs

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ use arrow::datatypes::DataType;
2222
use datafusion_common::{exec_err, Result};
2323
use datafusion_expr::function::Hint;
2424
use datafusion_expr::{
25-
ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature, Volatility,
25+
ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
26+
TypeSignature, Volatility,
2627
};
2728
use datafusion_macros::user_doc;
2829
use std::any::Any;
@@ -101,20 +102,16 @@ impl ScalarUDFImpl for BTrimFunc {
101102
}
102103
}
103104

104-
fn invoke_batch(
105-
&self,
106-
args: &[ColumnarValue],
107-
_number_rows: usize,
108-
) -> Result<ColumnarValue> {
109-
match args[0].data_type() {
105+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
106+
match args.args[0].data_type() {
110107
DataType::Utf8 | DataType::Utf8View => make_scalar_function(
111108
btrim::<i32>,
112109
vec![Hint::Pad, Hint::AcceptsSingular],
113-
)(args),
110+
)(&args.args),
114111
DataType::LargeUtf8 => make_scalar_function(
115112
btrim::<i64>,
116113
vec![Hint::Pad, Hint::AcceptsSingular],
117-
)(args),
114+
)(&args.args),
118115
other => exec_err!(
119116
"Unsupported data type {other:?} for function btrim,\
120117
expected Utf8, LargeUtf8 or Utf8View."

datafusion/functions/src/string/chr.rs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ use crate::utils::make_scalar_function;
2828
use datafusion_common::cast::as_int64_array;
2929
use datafusion_common::{exec_err, Result};
3030
use datafusion_expr::{ColumnarValue, Documentation, Volatility};
31-
use datafusion_expr::{ScalarUDFImpl, Signature};
31+
use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
3232
use datafusion_macros::user_doc;
3333

3434
/// Returns the character with the given code. chr(0) is disallowed because text data types cannot store that character.
@@ -111,12 +111,8 @@ impl ScalarUDFImpl for ChrFunc {
111111
Ok(Utf8)
112112
}
113113

114-
fn invoke_batch(
115-
&self,
116-
args: &[ColumnarValue],
117-
_number_rows: usize,
118-
) -> Result<ColumnarValue> {
119-
make_scalar_function(chr, vec![])(args)
114+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
115+
make_scalar_function(chr, vec![])(&args.args)
120116
}
121117

122118
fn documentation(&self) -> Option<&Documentation> {

datafusion/functions/src/string/concat.rs

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ use datafusion_common::{internal_err, plan_err, Result, ScalarValue};
3030
use datafusion_expr::expr::ScalarFunction;
3131
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
3232
use datafusion_expr::{lit, ColumnarValue, Documentation, Expr, Volatility};
33-
use datafusion_expr::{ScalarUDFImpl, Signature};
33+
use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
3434
use datafusion_macros::user_doc;
3535

3636
#[user_doc(
@@ -105,11 +105,9 @@ impl ScalarUDFImpl for ConcatFunc {
105105

106106
/// Concatenates the text representations of all the arguments. NULL arguments are ignored.
107107
/// concat('abcde', 2, NULL, 22) = 'abcde222'
108-
fn invoke_batch(
109-
&self,
110-
args: &[ColumnarValue],
111-
_number_rows: usize,
112-
) -> Result<ColumnarValue> {
108+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
109+
let ScalarFunctionArgs { args, .. } = args;
110+
113111
let mut return_datatype = DataType::Utf8;
114112
args.iter().for_each(|col| {
115113
if col.data_type() == DataType::Utf8View {
@@ -169,7 +167,7 @@ impl ScalarUDFImpl for ConcatFunc {
169167
let mut data_size = 0;
170168
let mut columns = Vec::with_capacity(args.len());
171169

172-
for arg in args {
170+
for arg in &args {
173171
match arg {
174172
ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value))
175173
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(maybe_value))
@@ -470,10 +468,14 @@ mod tests {
470468
None,
471469
Some("b"),
472470
])));
473-
let args = &[c0, c1, c2, c3, c4];
474471

475-
#[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch
476-
let result = ConcatFunc::new().invoke_batch(args, 3)?;
472+
let args = ScalarFunctionArgs {
473+
args: vec![c0, c1, c2, c3, c4],
474+
number_rows: 3,
475+
return_type: &Utf8,
476+
};
477+
478+
let result = ConcatFunc::new().invoke_with_args(args)?;
477479
let expected =
478480
Arc::new(StringViewArray::from(vec!["foo,x,a", "bar,,", "baz,z,b"]))
479481
as ArrayRef;

datafusion/functions/src/string/concat_ws.rs

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ use datafusion_common::{exec_err, internal_err, plan_err, Result, ScalarValue};
3030
use datafusion_expr::expr::ScalarFunction;
3131
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
3232
use datafusion_expr::{lit, ColumnarValue, Documentation, Expr, Volatility};
33-
use datafusion_expr::{ScalarUDFImpl, Signature};
33+
use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
3434
use datafusion_macros::user_doc;
3535

3636
#[user_doc(
@@ -102,11 +102,9 @@ impl ScalarUDFImpl for ConcatWsFunc {
102102

103103
/// Concatenates all but the first argument, with separators. The first argument is used as the separator string, and should not be NULL. Other NULL arguments are ignored.
104104
/// concat_ws(',', 'abcde', 2, NULL, 22) = 'abcde,2,22'
105-
fn invoke_batch(
106-
&self,
107-
args: &[ColumnarValue],
108-
_number_rows: usize,
109-
) -> Result<ColumnarValue> {
105+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
106+
let ScalarFunctionArgs { args, .. } = args;
107+
110108
// do not accept 0 arguments.
111109
if args.len() < 2 {
112110
return exec_err!(
@@ -404,14 +402,15 @@ fn is_null(expr: &Expr) -> bool {
404402
#[cfg(test)]
405403
mod tests {
406404
use std::sync::Arc;
405+
use std::vec;
407406

408407
use arrow::array::{Array, ArrayRef, StringArray};
409408
use arrow::datatypes::DataType::Utf8;
410409

411410
use crate::string::concat_ws::ConcatWsFunc;
412411
use datafusion_common::Result;
413412
use datafusion_common::ScalarValue;
414-
use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
413+
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
415414

416415
use crate::utils::test::test_function;
417416

@@ -482,10 +481,14 @@ mod tests {
482481
None,
483482
Some("z"),
484483
])));
485-
let args = &[c0, c1, c2];
486484

487-
#[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch
488-
let result = ConcatWsFunc::new().invoke_batch(args, 3)?;
485+
let args = ScalarFunctionArgs {
486+
args: vec![c0, c1, c2],
487+
number_rows: 3,
488+
return_type: &Utf8,
489+
};
490+
491+
let result = ConcatWsFunc::new().invoke_with_args(args)?;
489492
let expected =
490493
Arc::new(StringArray::from(vec!["foo,x", "bar", "baz,z"])) as ArrayRef;
491494
match &result {
@@ -508,10 +511,14 @@ mod tests {
508511
Some("y"),
509512
Some("z"),
510513
])));
511-
let args = &[c0, c1, c2];
512514

513-
#[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch
514-
let result = ConcatWsFunc::new().invoke_batch(args, 3)?;
515+
let args = ScalarFunctionArgs {
516+
args: vec![c0, c1, c2],
517+
number_rows: 3,
518+
return_type: &Utf8,
519+
};
520+
521+
let result = ConcatWsFunc::new().invoke_with_args(args)?;
515522
let expected =
516523
Arc::new(StringArray::from(vec![Some("foo,x"), None, Some("baz+z")]))
517524
as ArrayRef;

datafusion/functions/src/string/contains.rs

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ use datafusion_common::exec_err;
2424
use datafusion_common::DataFusionError;
2525
use datafusion_common::Result;
2626
use datafusion_expr::{
27-
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
27+
ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
28+
Volatility,
2829
};
2930
use datafusion_macros::user_doc;
3031
use std::any::Any;
@@ -81,12 +82,8 @@ impl ScalarUDFImpl for ContainsFunc {
8182
Ok(Boolean)
8283
}
8384

84-
fn invoke_batch(
85-
&self,
86-
args: &[ColumnarValue],
87-
_number_rows: usize,
88-
) -> Result<ColumnarValue> {
89-
make_scalar_function(contains, vec![])(args)
85+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
86+
make_scalar_function(contains, vec![])(&args.args)
9087
}
9188

9289
fn documentation(&self) -> Option<&Documentation> {
@@ -125,8 +122,9 @@ pub fn contains(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
125122
mod test {
126123
use super::ContainsFunc;
127124
use arrow::array::{BooleanArray, StringArray};
125+
use arrow::datatypes::DataType;
128126
use datafusion_common::ScalarValue;
129-
use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
127+
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
130128
use std::sync::Arc;
131129

132130
#[test]
@@ -137,8 +135,14 @@ mod test {
137135
Some("yyy?()"),
138136
])));
139137
let scalar = ColumnarValue::Scalar(ScalarValue::Utf8(Some("x?(".to_string())));
140-
#[allow(deprecated)] // TODO migrate UDF to invoke
141-
let actual = udf.invoke_batch(&[array, scalar], 2).unwrap();
138+
139+
let args = ScalarFunctionArgs {
140+
args: vec![array, scalar],
141+
number_rows: 2,
142+
return_type: &DataType::Boolean,
143+
};
144+
145+
let actual = udf.invoke_with_args(args).unwrap();
142146
let expect = ColumnarValue::Array(Arc::new(BooleanArray::from(vec![
143147
Some(true),
144148
Some(false),

datafusion/functions/src/string/ends_with.rs

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ use arrow::datatypes::DataType;
2424
use crate::utils::make_scalar_function;
2525
use datafusion_common::{internal_err, Result};
2626
use datafusion_expr::{ColumnarValue, Documentation, Volatility};
27-
use datafusion_expr::{ScalarUDFImpl, Signature};
27+
use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
2828
use datafusion_macros::user_doc;
2929

3030
#[user_doc(
@@ -84,14 +84,10 @@ impl ScalarUDFImpl for EndsWithFunc {
8484
Ok(DataType::Boolean)
8585
}
8686

87-
fn invoke_batch(
88-
&self,
89-
args: &[ColumnarValue],
90-
_number_rows: usize,
91-
) -> Result<ColumnarValue> {
92-
match args[0].data_type() {
87+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
88+
match args.args[0].data_type() {
9389
DataType::Utf8View | DataType::Utf8 | DataType::LargeUtf8 => {
94-
make_scalar_function(ends_with, vec![])(args)
90+
make_scalar_function(ends_with, vec![])(&args.args)
9591
}
9692
other => {
9793
internal_err!("Unsupported data type {other:?} for function ends_with. Expected Utf8, LargeUtf8 or Utf8View")?

datafusion/functions/src/string/levenshtein.rs

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
2626
use datafusion_common::utils::datafusion_strsim;
2727
use datafusion_common::{exec_err, utils::take_function_args, Result};
2828
use datafusion_expr::{ColumnarValue, Documentation};
29-
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
29+
use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
3030
use datafusion_macros::user_doc;
3131

3232
#[user_doc(
@@ -86,16 +86,14 @@ impl ScalarUDFImpl for LevenshteinFunc {
8686
utf8_to_int_type(&arg_types[0], "levenshtein")
8787
}
8888

89-
fn invoke_batch(
90-
&self,
91-
args: &[ColumnarValue],
92-
_number_rows: usize,
93-
) -> Result<ColumnarValue> {
94-
match args[0].data_type() {
89+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
90+
match args.args[0].data_type() {
9591
DataType::Utf8View | DataType::Utf8 => {
96-
make_scalar_function(levenshtein::<i32>, vec![])(args)
92+
make_scalar_function(levenshtein::<i32>, vec![])(&args.args)
93+
}
94+
DataType::LargeUtf8 => {
95+
make_scalar_function(levenshtein::<i64>, vec![])(&args.args)
9796
}
98-
DataType::LargeUtf8 => make_scalar_function(levenshtein::<i64>, vec![])(args),
9997
other => {
10098
exec_err!("Unsupported data type {other:?} for function levenshtein")
10199
}

datafusion/functions/src/string/lower.rs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use crate::string::common::to_lower;
2222
use crate::utils::utf8_to_str_type;
2323
use datafusion_common::Result;
2424
use datafusion_expr::{ColumnarValue, Documentation};
25-
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
25+
use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
2626
use datafusion_macros::user_doc;
2727

2828
#[user_doc(
@@ -77,12 +77,8 @@ impl ScalarUDFImpl for LowerFunc {
7777
utf8_to_str_type(&arg_types[0], "lower")
7878
}
7979

80-
fn invoke_batch(
81-
&self,
82-
args: &[ColumnarValue],
83-
_number_rows: usize,
84-
) -> Result<ColumnarValue> {
85-
to_lower(args, "lower")
80+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
81+
to_lower(&args.args, "lower")
8682
}
8783

8884
fn documentation(&self) -> Option<&Documentation> {
@@ -98,10 +94,14 @@ mod tests {
9894

9995
fn to_lower(input: ArrayRef, expected: ArrayRef) -> Result<()> {
10096
let func = LowerFunc::new();
101-
let batch_len = input.len();
102-
let args = vec![ColumnarValue::Array(input)];
103-
#[allow(deprecated)] // TODO migrate UDF to invoke
104-
let result = match func.invoke_batch(&args, batch_len)? {
97+
98+
let args = ScalarFunctionArgs {
99+
number_rows: input.len(),
100+
args: vec![ColumnarValue::Array(input)],
101+
return_type: &DataType::Utf8,
102+
};
103+
104+
let result = match func.invoke_with_args(args)? {
105105
ColumnarValue::Array(result) => result,
106106
_ => unreachable!("lower"),
107107
};

0 commit comments

Comments
 (0)