Skip to content

Migrate the internal and testing functions to invoke_with_args #14693

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions datafusion-examples/examples/optimizer_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ use arrow::datatypes::DataType;
use datafusion::common::tree_node::{Transformed, TreeNode};
use datafusion::common::{assert_batches_eq, Result, ScalarValue};
use datafusion::logical_expr::{
BinaryExpr, ColumnarValue, Expr, LogicalPlan, Operator, ScalarUDF, ScalarUDFImpl,
Signature, Volatility,
BinaryExpr, ColumnarValue, Expr, LogicalPlan, Operator, ScalarFunctionArgs,
ScalarUDF, ScalarUDFImpl, Signature, Volatility,
};
use datafusion::optimizer::ApplyOrder;
use datafusion::optimizer::{OptimizerConfig, OptimizerRule};
Expand Down Expand Up @@ -205,11 +205,7 @@ impl ScalarUDFImpl for MyEq {
Ok(DataType::Boolean)
}

fn invoke_batch(
&self,
_args: &[ColumnarValue],
_number_rows: usize,
) -> Result<ColumnarValue> {
fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
// this example simply returns "true" which is not what a real
// implementation would do.
Ok(ColumnarValue::Scalar(ScalarValue::from(true)))
Expand Down
12 changes: 5 additions & 7 deletions datafusion/core/tests/fuzz_cases/equivalence/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use datafusion_common::utils::{compare_rows, get_row_at_idx};
use datafusion_common::{exec_err, plan_datafusion_err, DataFusionError, Result};
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
use datafusion_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
};
use datafusion_physical_expr::equivalence::{EquivalenceClass, ProjectionMapping};
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use datafusion_physical_expr_common::sort_expr::LexOrdering;
Expand Down Expand Up @@ -581,12 +583,8 @@ impl ScalarUDFImpl for TestScalarUDF {
Ok(input[0].sort_properties)
}

fn invoke_batch(
&self,
args: &[ColumnarValue],
_number_rows: usize,
) -> Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(args)?;
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(&args.args)?;

let arr: ArrayRef = match args[0].data_type() {
DataType::Float64 => Arc::new({
Expand Down
12 changes: 1 addition & 11 deletions datafusion/core/tests/physical_optimizer/projection_pushdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@ use datafusion_common::Result;
use datafusion_common::{JoinSide, JoinType, ScalarValue};
use datafusion_execution::object_store::ObjectStoreUrl;
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_expr::{
ColumnarValue, Operator, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
};
use datafusion_expr::{Operator, ScalarUDF, ScalarUDFImpl, Signature, Volatility};
use datafusion_physical_expr::expressions::{
binary, col, BinaryExpr, CaseExpr, CastExpr, Column, Literal, NegativeExpr,
};
Expand Down Expand Up @@ -91,14 +89,6 @@ impl ScalarUDFImpl for DummyUDF {
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Int32)
}

fn invoke_batch(
&self,
_args: &[ColumnarValue],
_number_rows: usize,
) -> Result<ColumnarValue> {
unimplemented!("DummyUDF::invoke")
}
}

#[test]
Expand Down
53 changes: 14 additions & 39 deletions datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,16 @@ use datafusion::prelude::*;
use datafusion::{execution::registry::FunctionRegistry, test_util};
use datafusion_common::cast::{as_float64_array, as_int32_array};
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::utils::take_function_args;
use datafusion_common::{
assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_err, internal_err,
not_impl_err, plan_err, DFSchema, DataFusionError, HashMap, Result, ScalarValue,
assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_err, not_impl_err,
plan_err, DFSchema, DataFusionError, HashMap, Result, ScalarValue,
};
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
use datafusion_expr::{
Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody, LogicalPlanBuilder,
OperateFunctionArg, ReturnInfo, ReturnTypeArgs, ScalarUDF, ScalarUDFImpl, Signature,
Volatility,
OperateFunctionArg, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, ScalarUDF,
ScalarUDFImpl, Signature, Volatility,
};
use datafusion_functions_nested::range::range_udf;
use parking_lot::Mutex;
Expand Down Expand Up @@ -207,11 +208,7 @@ impl ScalarUDFImpl for Simple0ArgsScalarUDF {
Ok(self.return_type.clone())
}

fn invoke_batch(
&self,
_args: &[ColumnarValue],
_number_rows: usize,
) -> Result<ColumnarValue> {
fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(100))))
}
}
Expand Down Expand Up @@ -518,16 +515,13 @@ impl ScalarUDFImpl for AddIndexToStringVolatileScalarUDF {
Ok(self.return_type.clone())
}

fn invoke_batch(
&self,
args: &[ColumnarValue],
number_rows: usize,
) -> Result<ColumnarValue> {
let answer = match &args[0] {
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
let [arg] = take_function_args(self.name(), &args.args)?;
let answer = match arg {
// When called with static arguments, the result is returned as an array.
ColumnarValue::Scalar(ScalarValue::Utf8(Some(value))) => {
let mut answer = vec![];
for index in 1..=number_rows {
for index in 1..=args.number_rows {
// When calling a function with immutable arguments, the result is returned with ")".
// Example: SELECT add_index_to_string('const_value') FROM table;
answer.push(index.to_string() + ") " + value);
Expand Down Expand Up @@ -713,14 +707,6 @@ impl ScalarUDFImpl for CastToI64UDF {
// return the newly written argument to DataFusion
Ok(ExprSimplifyResult::Simplified(new_expr))
}

fn invoke_batch(
&self,
_args: &[ColumnarValue],
_number_rows: usize,
) -> Result<ColumnarValue> {
unimplemented!("Function should have been simplified prior to evaluation")
}
}

#[tokio::test]
Expand Down Expand Up @@ -850,17 +836,14 @@ impl ScalarUDFImpl for TakeUDF {
}

// The actual implementation
fn invoke_batch(
&self,
args: &[ColumnarValue],
_number_rows: usize,
) -> Result<ColumnarValue> {
let take_idx = match &args[2] {
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
let [_arg0, _arg1, arg2] = take_function_args(self.name(), &args.args)?;
let take_idx = match arg2 {
ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) if v == "0" => 0,
ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) if v == "1" => 1,
_ => unreachable!(),
};
match &args[take_idx] {
match &args.args[take_idx] {
ColumnarValue::Array(array) => Ok(ColumnarValue::Array(array.clone())),
ColumnarValue::Scalar(_) => unimplemented!(),
}
Expand Down Expand Up @@ -963,14 +946,6 @@ impl ScalarUDFImpl for ScalarFunctionWrapper {
Ok(self.return_type.clone())
}

fn invoke_batch(
&self,
_args: &[ColumnarValue],
_number_rows: usize,
) -> Result<ColumnarValue> {
internal_err!("This function should not get invoked!")
}

fn simplify(
&self,
args: Vec<Expr>,
Expand Down
10 changes: 3 additions & 7 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use crate::function::{
};
use crate::{
conditional_expressions::CaseBuilder, expr::Sort, logical_plan::Subquery,
AggregateUDF, Expr, LogicalPlan, Operator, PartitionEvaluator,
AggregateUDF, Expr, LogicalPlan, Operator, PartitionEvaluator, ScalarFunctionArgs,
ScalarFunctionImplementation, ScalarUDF, Signature, Volatility,
};
use crate::{
Expand Down Expand Up @@ -477,12 +477,8 @@ impl ScalarUDFImpl for SimpleScalarUDF {
Ok(self.return_type.clone())
}

fn invoke_batch(
&self,
args: &[ColumnarValue],
_number_rows: usize,
) -> Result<ColumnarValue> {
(self.fun)(args)
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
(self.fun)(&args.args)
}
}

Expand Down
9 changes: 2 additions & 7 deletions datafusion/expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -899,13 +899,8 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl {
self.inner.return_type_from_args(args)
}

fn invoke_batch(
&self,
args: &[ColumnarValue],
number_rows: usize,
) -> Result<ColumnarValue> {
#[allow(deprecated)]
self.inner.invoke_batch(args, number_rows)
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
self.inner.invoke_with_args(args)
}

fn simplify(
Expand Down
10 changes: 3 additions & 7 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1047,8 +1047,8 @@ mod test {
use datafusion_expr::{
cast, col, create_udaf, is_true, lit, AccumulatorFactoryFunction, AggregateUDF,
BinaryExpr, Case, ColumnarValue, Expr, ExprSchemable, Filter, LogicalPlan,
Operator, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF, Subquery,
Volatility,
Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
SimpleAggregateUDF, Subquery, Volatility,
};
use datafusion_functions_aggregate::average::AvgAccumulator;

Expand Down Expand Up @@ -1266,11 +1266,7 @@ mod test {
Ok(Utf8)
}

fn invoke_batch(
&self,
_args: &[ColumnarValue],
_number_rows: usize,
) -> Result<ColumnarValue> {
fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
Ok(ColumnarValue::Scalar(ScalarValue::from("a")))
}
}
Expand Down
8 changes: 0 additions & 8 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1703,13 +1703,5 @@ mod test {
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Float64)
}

fn invoke_batch(
&self,
_args: &[ColumnarValue],
_number_rows: usize,
) -> Result<ColumnarValue> {
unimplemented!()
}
}
}
11 changes: 4 additions & 7 deletions datafusion/optimizer/src/push_down_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1386,8 +1386,9 @@ mod tests {
use datafusion_expr::logical_plan::table_scan;
use datafusion_expr::{
col, in_list, in_subquery, lit, ColumnarValue, ExprFunctionExt, Extension,
LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, TableSource, TableType,
UserDefinedLogicalNodeCore, Volatility, WindowFunctionDefinition,
LogicalPlanBuilder, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
TableSource, TableType, UserDefinedLogicalNodeCore, Volatility,
WindowFunctionDefinition,
};

use crate::optimizer::Optimizer;
Expand Down Expand Up @@ -3615,11 +3616,7 @@ Projection: a, b
Ok(DataType::Int32)
}

fn invoke_batch(
&self,
_args: &[ColumnarValue],
_number_rows: usize,
) -> Result<ColumnarValue> {
fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
Ok(ColumnarValue::Scalar(ScalarValue::from(1)))
}
}
Expand Down
12 changes: 5 additions & 7 deletions datafusion/physical-expr/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,9 @@ pub(crate) mod tests {
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::{exec_err, DataFusionError, ScalarValue};
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
use datafusion_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
};

use petgraph::visit::Bfs;

Expand Down Expand Up @@ -309,12 +311,8 @@ pub(crate) mod tests {
Ok(input[0].sort_properties)
}

fn invoke_batch(
&self,
args: &[ColumnarValue],
_number_rows: usize,
) -> Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(args)?;
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
let args = ColumnarValue::values_to_arrays(&args.args)?;

let arr: ArrayRef = match args[0].data_type() {
DataType::Float64 => Arc::new({
Expand Down
11 changes: 2 additions & 9 deletions datafusion/proto/tests/cases/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ use std::fmt::Debug;
use datafusion_common::plan_err;
use datafusion_expr::function::AccumulatorArgs;
use datafusion_expr::{
Accumulator, AggregateUDFImpl, ColumnarValue, PartitionEvaluator, ScalarUDFImpl,
Signature, Volatility, WindowUDFImpl,
Accumulator, AggregateUDFImpl, PartitionEvaluator, ScalarUDFImpl, Signature,
Volatility, WindowUDFImpl,
};
use datafusion_functions_window_common::field::WindowUDFFieldArgs;
use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
Expand Down Expand Up @@ -69,13 +69,6 @@ impl ScalarUDFImpl for MyRegexUdf {
plan_err!("regex_udf only accepts Utf8 arguments")
}
}
fn invoke_batch(
&self,
_args: &[ColumnarValue],
_number_rows: usize,
) -> datafusion_common::Result<ColumnarValue> {
unimplemented!()
}
fn aliases(&self) -> &[String] {
&self.aliases
}
Expand Down
12 changes: 2 additions & 10 deletions datafusion/sql/src/unparser/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1655,8 +1655,8 @@ mod tests {
use datafusion_expr::{
case, cast, col, cube, exists, grouping_set, interval_datetime_lit,
interval_year_month_lit, lit, not, not_exists, out_ref_col, placeholder, rollup,
table_scan, try_cast, when, wildcard, ColumnarValue, ScalarUDF, ScalarUDFImpl,
Signature, Volatility, WindowFrame, WindowFunctionDefinition,
table_scan, try_cast, when, wildcard, ScalarUDF, ScalarUDFImpl, Signature,
Volatility, WindowFrame, WindowFunctionDefinition,
};
use datafusion_expr::{interval_month_day_nano_lit, ExprFunctionExt};
use datafusion_functions::expr_fn::{get_field, named_struct};
Expand Down Expand Up @@ -1705,14 +1705,6 @@ mod tests {
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Int32)
}

fn invoke_batch(
&self,
_args: &[ColumnarValue],
_number_rows: usize,
) -> Result<ColumnarValue> {
unimplemented!("DummyUDF::invoke")
}
}
// See sql::tests for E2E tests.

Expand Down
12 changes: 2 additions & 10 deletions datafusion/sql/tests/sql_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ use datafusion_expr::{
col,
logical_plan::{LogicalPlan, Prepare},
test::function_stub::sum_udaf,
ColumnarValue, CreateIndex, DdlStatement, ScalarUDF, ScalarUDFImpl, Signature,
Statement, Volatility,
CreateIndex, DdlStatement, ScalarUDF, ScalarUDFImpl, Signature, Statement,
Volatility,
};
use datafusion_functions::{string, unicode};
use datafusion_sql::{
Expand Down Expand Up @@ -2646,14 +2646,6 @@ impl ScalarUDFImpl for DummyUDF {
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(self.return_type.clone())
}

fn invoke_batch(
&self,
_args: &[ColumnarValue],
_number_rows: usize,
) -> Result<ColumnarValue> {
unimplemented!("DummyUDF::invoke")
}
}

/// Create logical plan, write with formatter, compare to expected output
Expand Down
Loading