Skip to content

Commit ae73371

Browse files
committed
Added support for ScalarUDFImpl::invoke_with_return_type where the invoke is passed the return type created for the udf instance
1 parent 398d5f6 commit ae73371

File tree

13 files changed

+89
-41
lines changed

13 files changed

+89
-41
lines changed

datafusion/expr/src/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
9292
pub use udaf::{
9393
aggregate_doc_sections, AggregateUDF, AggregateUDFImpl, ReversedUDAF, StatisticsArgs,
9494
};
95-
pub use udf::{scalar_doc_sections, ScalarUDF, ScalarUDFImpl};
95+
pub use udf::{scalar_doc_sections, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl};
9696
pub use udf_docs::{DocSection, Documentation, DocumentationBuilder};
9797
pub use udwf::{window_doc_sections, ReversedUDWF, WindowUDF, WindowUDFImpl};
9898
pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};

datafusion/expr/src/udf.rs

+39-7
Original file line numberDiff line numberDiff line change
@@ -203,9 +203,6 @@ impl ScalarUDF {
203203
self.inner.simplify(args, info)
204204
}
205205

206-
/// Invoke the function on `args`, returning the appropriate result.
207-
///
208-
/// See [`ScalarUDFImpl::invoke`] for more details.
209206
#[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")]
210207
pub fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
211208
#[allow(deprecated)]
@@ -216,17 +213,23 @@ impl ScalarUDF {
216213
self.inner.is_nullable(args, schema)
217214
}
218215

219-
/// Invoke the function with `args` and number of rows, returning the appropriate result.
220-
///
221-
/// See [`ScalarUDFImpl::invoke_batch`] for more details.
216+
#[deprecated(since = "43.0.0", note = "Use `invoke_batch` instead")]
222217
pub fn invoke_batch(
223218
&self,
224219
args: &[ColumnarValue],
225220
number_rows: usize,
226221
) -> Result<ColumnarValue> {
222+
#[allow(deprecated)]
227223
self.inner.invoke_batch(args, number_rows)
228224
}
229225

226+
/// Invoke the function on `args`, returning the appropriate result.
227+
///
228+
/// See [`ScalarUDFImpl::invoke_with_args`] for more details.
229+
pub fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
230+
self.inner.invoke_with_args(args)
231+
}
232+
230233
/// Invoke the function without `args` but number of rows, returning the appropriate result.
231234
///
232235
/// See [`ScalarUDFImpl::invoke_no_args`] for more details.
@@ -324,6 +327,16 @@ where
324327
}
325328
}
326329

330+
pub struct ScalarFunctionArgs<'a> {
331+
// The evaluated arguments to the function
332+
pub args: &'a [ColumnarValue],
333+
// The number of rows in record batch being evaluated
334+
pub number_rows: usize,
335+
// The return type of the scalar function returned (from `return_type` or `return_type_from_exprs`)
336+
// when creating the physical expression from the logical expression
337+
pub return_type: &'a DataType,
338+
}
339+
327340
/// Trait for implementing [`ScalarUDF`].
328341
///
329342
/// This trait exposes the full API for implementing user defined functions and
@@ -356,7 +369,7 @@ where
356369
/// }
357370
/// }
358371
/// }
359-
///
372+
///
360373
/// static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
361374
///
362375
/// fn get_doc() -> &'static Documentation {
@@ -518,6 +531,7 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
518531
///
519532
/// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments
520533
/// to arrays, which will likely be simpler code, but be slower.
534+
#[deprecated(since = "43.0.0", note = "Use `invoke_with_args` instead")]
521535
fn invoke_batch(
522536
&self,
523537
args: &[ColumnarValue],
@@ -537,6 +551,23 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
537551
}
538552
}
539553

554+
/// Invoke the function with `args: ScalarFunctionArgs` returning the appropriate result.
555+
///
556+
/// The function will be invoked with a struct `ScalarFunctionArgs`
557+
///
558+
/// # Performance
559+
///
560+
/// For the best performance, the implementations should handle the common case
561+
/// when one or more of their arguments are constant values (aka
562+
/// [`ColumnarValue::Scalar`]).
563+
///
564+
/// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments
565+
/// to arrays, which will likely be simpler code, but be slower.
566+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
567+
#[allow(deprecated)]
568+
self.invoke_batch(args.args, args.number_rows)
569+
}
570+
540571
/// Invoke the function without `args`, instead the number of rows are provided,
541572
/// returning the appropriate result.
542573
#[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")]
@@ -767,6 +798,7 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl {
767798
args: &[ColumnarValue],
768799
number_rows: usize,
769800
) -> Result<ColumnarValue> {
801+
#[allow(deprecated)]
770802
self.inner.invoke_batch(args, number_rows)
771803
}
772804

datafusion/functions/benches/random.rs

+2
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ fn criterion_benchmark(c: &mut Criterion) {
2929
c.bench_function("random_1M_rows_batch_8192", |b| {
3030
b.iter(|| {
3131
for _ in 0..iterations {
32+
#[allow(deprecated)] // TODO: migrate to invoke_with_args
3233
black_box(random_func.invoke_batch(&[], 8192).unwrap());
3334
}
3435
})
@@ -39,6 +40,7 @@ fn criterion_benchmark(c: &mut Criterion) {
3940
c.bench_function("random_1M_rows_batch_128", |b| {
4041
b.iter(|| {
4142
for _ in 0..iterations_128 {
43+
#[allow(deprecated)] // TODO: migrate to invoke_with_args
4244
black_box(random_func.invoke_batch(&[], 128).unwrap());
4345
}
4446
})

datafusion/functions/src/core/version.rs

+1
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ mod test {
121121
#[tokio::test]
122122
async fn test_version_udf() {
123123
let version_udf = ScalarUDF::from(VersionFunc::new());
124+
#[allow(deprecated)] // TODO: migrate to invoke_with_args
124125
let version = version_udf.invoke_batch(&[], 1).unwrap();
125126

126127
if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(version))) = version {

datafusion/functions/src/datetime/to_local_time.rs

+7-2
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,7 @@ mod tests {
431431
use arrow::datatypes::{DataType, TimeUnit};
432432
use chrono::NaiveDateTime;
433433
use datafusion_common::ScalarValue;
434-
use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
434+
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
435435

436436
use super::{adjust_to_local_time, ToLocalTimeFunc};
437437

@@ -558,7 +558,11 @@ mod tests {
558558

559559
fn test_to_local_time_helper(input: ScalarValue, expected: ScalarValue) {
560560
let res = ToLocalTimeFunc::new()
561-
.invoke_batch(&[ColumnarValue::Scalar(input)], 1)
561+
.invoke_with_args(ScalarFunctionArgs {
562+
args: &[ColumnarValue::Scalar(input)],
563+
number_rows: 1,
564+
return_type: &expected.data_type(),
565+
})
562566
.unwrap();
563567
match res {
564568
ColumnarValue::Scalar(res) => {
@@ -617,6 +621,7 @@ mod tests {
617621
.map(|s| Some(string_to_timestamp_nanos(s).unwrap()))
618622
.collect::<TimestampNanosecondArray>();
619623
let batch_size = input.len();
624+
#[allow(deprecated)] // TODO: migrate to invoke_with_args
620625
let result = ToLocalTimeFunc::new()
621626
.invoke_batch(&[ColumnarValue::Array(Arc::new(input))], batch_size)
622627
.unwrap();

datafusion/functions/src/datetime/to_timestamp.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -1008,7 +1008,7 @@ mod tests {
10081008
for array in arrays {
10091009
let rt = udf.return_type(&[array.data_type()]).unwrap();
10101010
assert!(matches!(rt, Timestamp(_, Some(_))));
1011-
1011+
#[allow(deprecated)] // TODO: migrate to invoke_with_args
10121012
let res = udf
10131013
.invoke_batch(&[array.clone()], 1)
10141014
.expect("that to_timestamp parsed values without error");
@@ -1051,7 +1051,7 @@ mod tests {
10511051
for array in arrays {
10521052
let rt = udf.return_type(&[array.data_type()]).unwrap();
10531053
assert!(matches!(rt, Timestamp(_, None)));
1054-
1054+
#[allow(deprecated)] // TODO: migrate to invoke_with_args
10551055
let res = udf
10561056
.invoke_batch(&[array.clone()], 1)
10571057
.expect("that to_timestamp parsed values without error");

datafusion/functions/src/datetime/to_unixtime.rs

+1
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ impl ScalarUDFImpl for ToUnixtimeFunc {
8383
DataType::Date64 | DataType::Date32 | DataType::Timestamp(_, None) => args[0]
8484
.cast_to(&DataType::Timestamp(TimeUnit::Second, None), None)?
8585
.cast_to(&DataType::Int64, None),
86+
#[allow(deprecated)] // TODO: migrate to invoke_with_args
8687
DataType::Utf8 => ToTimestampSecondsFunc::new()
8788
.invoke_batch(args, batch_size)?
8889
.cast_to(&DataType::Int64, None),

datafusion/functions/src/math/log.rs

+10-10
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ mod tests {
277277
]))), // num
278278
ColumnarValue::Array(Arc::new(Int64Array::from(vec![5, 10, 15, 20]))),
279279
];
280-
280+
#[allow(deprecated)] // TODO: migrate to invoke_with_args
281281
let _ = LogFunc::new().invoke_batch(&args, 4);
282282
}
283283

@@ -286,7 +286,7 @@ mod tests {
286286
let args = [
287287
ColumnarValue::Array(Arc::new(Int64Array::from(vec![10]))), // num
288288
];
289-
289+
#[allow(deprecated)] // TODO: migrate to invoke_with_args
290290
let result = LogFunc::new().invoke_batch(&args, 1);
291291
result.expect_err("expected error");
292292
}
@@ -296,7 +296,7 @@ mod tests {
296296
let args = [
297297
ColumnarValue::Scalar(ScalarValue::Float32(Some(10.0))), // num
298298
];
299-
299+
#[allow(deprecated)] // TODO: migrate to invoke_with_args
300300
let result = LogFunc::new()
301301
.invoke_batch(&args, 1)
302302
.expect("failed to initialize function log");
@@ -320,7 +320,7 @@ mod tests {
320320
let args = [
321321
ColumnarValue::Scalar(ScalarValue::Float64(Some(10.0))), // num
322322
];
323-
323+
#[allow(deprecated)] // TODO: migrate to invoke_with_args
324324
let result = LogFunc::new()
325325
.invoke_batch(&args, 1)
326326
.expect("failed to initialize function log");
@@ -345,7 +345,7 @@ mod tests {
345345
ColumnarValue::Scalar(ScalarValue::Float32(Some(2.0))), // num
346346
ColumnarValue::Scalar(ScalarValue::Float32(Some(32.0))), // num
347347
];
348-
348+
#[allow(deprecated)] // TODO: migrate to invoke_with_args
349349
let result = LogFunc::new()
350350
.invoke_batch(&args, 1)
351351
.expect("failed to initialize function log");
@@ -370,7 +370,7 @@ mod tests {
370370
ColumnarValue::Scalar(ScalarValue::Float64(Some(2.0))), // num
371371
ColumnarValue::Scalar(ScalarValue::Float64(Some(64.0))), // num
372372
];
373-
373+
#[allow(deprecated)] // TODO: migrate to invoke_with_args
374374
let result = LogFunc::new()
375375
.invoke_batch(&args, 1)
376376
.expect("failed to initialize function log");
@@ -396,7 +396,7 @@ mod tests {
396396
10.0, 100.0, 1000.0, 10000.0,
397397
]))), // num
398398
];
399-
399+
#[allow(deprecated)] // TODO: migrate to invoke_with_args
400400
let result = LogFunc::new()
401401
.invoke_batch(&args, 4)
402402
.expect("failed to initialize function log");
@@ -425,7 +425,7 @@ mod tests {
425425
10.0, 100.0, 1000.0, 10000.0,
426426
]))), // num
427427
];
428-
428+
#[allow(deprecated)] // TODO: migrate to invoke_with_args
429429
let result = LogFunc::new()
430430
.invoke_batch(&args, 4)
431431
.expect("failed to initialize function log");
@@ -455,7 +455,7 @@ mod tests {
455455
8.0, 4.0, 81.0, 625.0,
456456
]))), // num
457457
];
458-
458+
#[allow(deprecated)] // TODO: migrate to invoke_with_args
459459
let result = LogFunc::new()
460460
.invoke_batch(&args, 4)
461461
.expect("failed to initialize function log");
@@ -485,7 +485,7 @@ mod tests {
485485
8.0, 4.0, 81.0, 625.0,
486486
]))), // num
487487
];
488-
488+
#[allow(deprecated)] // TODO: migrate to invoke_with_args
489489
let result = LogFunc::new()
490490
.invoke_batch(&args, 4)
491491
.expect("failed to initialize function log");

datafusion/functions/src/math/power.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ mod tests {
205205
ColumnarValue::Array(Arc::new(Float64Array::from(vec![2.0, 2.0, 3.0, 5.0]))), // base
206206
ColumnarValue::Array(Arc::new(Float64Array::from(vec![3.0, 2.0, 4.0, 4.0]))), // exponent
207207
];
208-
208+
#[allow(deprecated)] // TODO: migrate to invoke_with_args
209209
let result = PowerFunc::new()
210210
.invoke_batch(&args, 4)
211211
.expect("failed to initialize function power");
@@ -232,7 +232,7 @@ mod tests {
232232
ColumnarValue::Array(Arc::new(Int64Array::from(vec![2, 2, 3, 5]))), // base
233233
ColumnarValue::Array(Arc::new(Int64Array::from(vec![3, 2, 4, 4]))), // exponent
234234
];
235-
235+
#[allow(deprecated)] // TODO: migrate to invoke_with_args
236236
let result = PowerFunc::new()
237237
.invoke_batch(&args, 4)
238238
.expect("failed to initialize function power");

datafusion/functions/src/math/signum.rs

+2
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ mod test {
167167
f32::NEG_INFINITY,
168168
]));
169169
let batch_size = array.len();
170+
#[allow(deprecated)] // TODO: migrate to invoke_with_args
170171
let result = SignumFunc::new()
171172
.invoke_batch(&[ColumnarValue::Array(array)], batch_size)
172173
.expect("failed to initialize function signum");
@@ -207,6 +208,7 @@ mod test {
207208
f64::NEG_INFINITY,
208209
]));
209210
let batch_size = array.len();
211+
#[allow(deprecated)] // TODO: migrate to invoke_with_args
210212
let result = SignumFunc::new()
211213
.invoke_batch(&[ColumnarValue::Array(array)], batch_size)
212214
.expect("failed to initialize function signum");

0 commit comments

Comments
 (0)