Skip to content

Commit 5239d1a

Browse files
findepi, alamb, mbrobbel, comphead
authored
Validate and unpack function arguments tersely (#14513)
* Validate and unpack function arguments tersely

  Add a `take_function_args` helper that provides convenient unpacking of function arguments along with validation that the provided argument count matches the expected. A few functions are updated to leverage the new pattern to demonstrate its usefulness.

* Add example in rust doc

  Co-authored-by: Andrew Lamb <[email protected]>

* fix fmt

* Export function utils publicly

  this exports only the newly added take_function_args function. all other utils members are pub(crate)

* use compact format pattern

  Co-authored-by: Matthijs Brobbel <[email protected]>

* fix example

* fixup! fix example

* fix license header

  Co-authored-by: Oleks V <[email protected]>

* Name args in nvl2 and use take_function_args in execution too

---------

Co-authored-by: Andrew Lamb <[email protected]>
Co-authored-by: Matthijs Brobbel <[email protected]>
Co-authored-by: Oleks V <[email protected]>
1 parent 304488d commit 5239d1a

File tree

14 files changed

+132
-146
lines changed

14 files changed

+132
-146
lines changed

datafusion/functions/src/core/arrow_cast.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ use datafusion_common::{
2525
use datafusion_common::{exec_datafusion_err, DataFusionError};
2626
use std::any::Any;
2727

28+
use crate::utils::take_function_args;
2829
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
2930
use datafusion_expr::{
3031
ColumnarValue, Documentation, Expr, ReturnInfo, ReturnTypeArgs, ScalarUDFImpl,
@@ -117,10 +118,9 @@ impl ScalarUDFImpl for ArrowCastFunc {
117118
fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result<ReturnInfo> {
118119
let nullable = args.nullables.iter().any(|&nullable| nullable);
119120

120-
// Length check handled in the signature
121-
debug_assert_eq!(args.scalar_arguments.len(), 2);
121+
let [_, type_arg] = take_function_args(self.name(), args.scalar_arguments)?;
122122

123-
args.scalar_arguments[1]
123+
type_arg
124124
.and_then(|sv| sv.try_as_str().flatten().filter(|s| !s.is_empty()))
125125
.map_or_else(
126126
|| {

datafusion/functions/src/core/arrowtypeof.rs

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use crate::utils::take_function_args;
1819
use arrow::datatypes::DataType;
19-
use datafusion_common::{exec_err, Result, ScalarValue};
20+
use datafusion_common::{Result, ScalarValue};
2021
use datafusion_expr::{ColumnarValue, Documentation};
2122
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
2223
use datafusion_macros::user_doc;
@@ -80,14 +81,8 @@ impl ScalarUDFImpl for ArrowTypeOfFunc {
8081
args: &[ColumnarValue],
8182
_number_rows: usize,
8283
) -> Result<ColumnarValue> {
83-
if args.len() != 1 {
84-
return exec_err!(
85-
"arrow_typeof function requires 1 arguments, got {}",
86-
args.len()
87-
);
88-
}
89-
90-
let input_data_type = args[0].data_type();
84+
let [arg] = take_function_args(self.name(), args)?;
85+
let input_data_type = arg.data_type();
9186
Ok(ColumnarValue::Scalar(ScalarValue::from(format!(
9287
"{input_data_type}"
9388
))))

datafusion/functions/src/core/getfield.rs

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use crate::utils::take_function_args;
1819
use arrow::array::{
1920
make_array, Array, Capacities, MutableArrayData, Scalar, StringArray,
2021
};
@@ -99,14 +100,9 @@ impl ScalarUDFImpl for GetFieldFunc {
99100
}
100101

101102
fn display_name(&self, args: &[Expr]) -> Result<String> {
102-
if args.len() != 2 {
103-
return exec_err!(
104-
"get_field function requires 2 arguments, got {}",
105-
args.len()
106-
);
107-
}
103+
let [base, field_name] = take_function_args(self.name(), args)?;
108104

109-
let name = match &args[1] {
105+
let name = match field_name {
110106
Expr::Literal(name) => name,
111107
_ => {
112108
return exec_err!(
@@ -115,7 +111,7 @@ impl ScalarUDFImpl for GetFieldFunc {
115111
}
116112
};
117113

118-
Ok(format!("{}[{}]", args[0], name))
114+
Ok(format!("{base}[{name}]"))
119115
}
120116

121117
fn schema_name(&self, args: &[Expr]) -> Result<String> {

datafusion/functions/src/core/nullif.rs

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,17 @@
1616
// under the License.
1717

1818
use arrow::datatypes::DataType;
19-
use datafusion_common::{exec_err, Result};
19+
use datafusion_common::Result;
2020
use datafusion_expr::{ColumnarValue, Documentation};
2121

22+
use crate::utils::take_function_args;
2223
use arrow::compute::kernels::cmp::eq;
2324
use arrow::compute::kernels::nullif::nullif;
2425
use datafusion_common::ScalarValue;
2526
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
2627
use datafusion_macros::user_doc;
2728
use std::any::Any;
29+
2830
#[user_doc(
2931
doc_section(label = "Conditional Functions"),
3032
description = "Returns _null_ if _expression1_ equals _expression2_; otherwise it returns _expression1_.
@@ -119,14 +121,7 @@ impl ScalarUDFImpl for NullIfFunc {
119121
/// 1 - if the left is equal to this expr2, then the result is NULL, otherwise left value is passed.
120122
///
121123
fn nullif_func(args: &[ColumnarValue]) -> Result<ColumnarValue> {
122-
if args.len() != 2 {
123-
return exec_err!(
124-
"{:?} args were supplied but NULLIF takes exactly two args",
125-
args.len()
126-
);
127-
}
128-
129-
let (lhs, rhs) = (&args[0], &args[1]);
124+
let [lhs, rhs] = take_function_args("nullif", args)?;
130125

131126
match (lhs, rhs) {
132127
(ColumnarValue::Array(lhs), ColumnarValue::Scalar(rhs)) => {

datafusion/functions/src/core/nvl.rs

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,18 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use crate::utils::take_function_args;
1819
use arrow::array::Array;
1920
use arrow::compute::is_not_null;
2021
use arrow::compute::kernels::zip::zip;
2122
use arrow::datatypes::DataType;
22-
use datafusion_common::{internal_err, Result};
23+
use datafusion_common::Result;
2324
use datafusion_expr::{
2425
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
2526
};
2627
use datafusion_macros::user_doc;
2728
use std::sync::Arc;
29+
2830
#[user_doc(
2931
doc_section(label = "Conditional Functions"),
3032
description = "Returns _expression2_ if _expression1_ is NULL otherwise it returns _expression1_.",
@@ -133,13 +135,8 @@ impl ScalarUDFImpl for NVLFunc {
133135
}
134136

135137
fn nvl_func(args: &[ColumnarValue]) -> Result<ColumnarValue> {
136-
if args.len() != 2 {
137-
return internal_err!(
138-
"{:?} args were supplied but NVL/IFNULL takes exactly two args",
139-
args.len()
140-
);
141-
}
142-
let (lhs_array, rhs_array) = match (&args[0], &args[1]) {
138+
let [lhs, rhs] = take_function_args("nvl/ifnull", args)?;
139+
let (lhs_array, rhs_array) = match (lhs, rhs) {
143140
(ColumnarValue::Array(lhs), ColumnarValue::Scalar(rhs)) => {
144141
(Arc::clone(lhs), rhs.to_array_of_size(lhs.len())?)
145142
}

datafusion/functions/src/core/nvl2.rs

Lines changed: 26 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,12 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use crate::utils::take_function_args;
1819
use arrow::array::Array;
1920
use arrow::compute::is_not_null;
2021
use arrow::compute::kernels::zip::zip;
2122
use arrow::datatypes::DataType;
22-
use datafusion_common::{exec_err, internal_err, Result};
23+
use datafusion_common::{internal_err, Result};
2324
use datafusion_expr::{
2425
type_coercion::binary::comparison_coercion, ColumnarValue, Documentation,
2526
ScalarUDFImpl, Signature, Volatility,
@@ -104,27 +105,22 @@ impl ScalarUDFImpl for NVL2Func {
104105
}
105106

106107
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
107-
if arg_types.len() != 3 {
108-
return exec_err!(
109-
"NVL2 takes exactly three arguments, but got {}",
110-
arg_types.len()
111-
);
112-
}
113-
let new_type = arg_types.iter().skip(1).try_fold(
114-
arg_types.first().unwrap().clone(),
115-
|acc, x| {
116-
// The coerced types found by `comparison_coercion` are not guaranteed to be
117-
// coercible for the arguments. `comparison_coercion` returns more loose
118-
// types that can be coerced to both `acc` and `x` for comparison purpose.
119-
// See `maybe_data_types` for the actual coercion.
120-
let coerced_type = comparison_coercion(&acc, x);
121-
if let Some(coerced_type) = coerced_type {
122-
Ok(coerced_type)
123-
} else {
124-
internal_err!("Coercion from {acc:?} to {x:?} failed.")
125-
}
126-
},
127-
)?;
108+
let [tested, if_non_null, if_null] = take_function_args(self.name(), arg_types)?;
109+
let new_type =
110+
[if_non_null, if_null]
111+
.iter()
112+
.try_fold(tested.clone(), |acc, x| {
113+
// The coerced types found by `comparison_coercion` are not guaranteed to be
114+
// coercible for the arguments. `comparison_coercion` returns more loose
115+
// types that can be coerced to both `acc` and `x` for comparison purpose.
116+
// See `maybe_data_types` for the actual coercion.
117+
let coerced_type = comparison_coercion(&acc, x);
118+
if let Some(coerced_type) = coerced_type {
119+
Ok(coerced_type)
120+
} else {
121+
internal_err!("Coercion from {acc:?} to {x:?} failed.")
122+
}
123+
})?;
128124
Ok(vec![new_type; arg_types.len()])
129125
}
130126

@@ -134,12 +130,6 @@ impl ScalarUDFImpl for NVL2Func {
134130
}
135131

136132
fn nvl2_func(args: &[ColumnarValue]) -> Result<ColumnarValue> {
137-
if args.len() != 3 {
138-
return internal_err!(
139-
"{:?} args were supplied but NVL2 takes exactly three args",
140-
args.len()
141-
);
142-
}
143133
let mut len = 1;
144134
let mut is_array = false;
145135
for arg in args {
@@ -157,20 +147,22 @@ fn nvl2_func(args: &[ColumnarValue]) -> Result<ColumnarValue> {
157147
ColumnarValue::Array(array) => Ok(Arc::clone(array)),
158148
})
159149
.collect::<Result<Vec<_>>>()?;
160-
let to_apply = is_not_null(&args[0])?;
161-
let value = zip(&to_apply, &args[1], &args[2])?;
150+
let [tested, if_non_null, if_null] = take_function_args("nvl2", args)?;
151+
let to_apply = is_not_null(&tested)?;
152+
let value = zip(&to_apply, &if_non_null, &if_null)?;
162153
Ok(ColumnarValue::Array(value))
163154
} else {
164-
let mut current_value = &args[1];
165-
match &args[0] {
155+
let [tested, if_non_null, if_null] = take_function_args("nvl2", args)?;
156+
match &tested {
166157
ColumnarValue::Array(_) => {
167158
internal_err!("except Scalar value, but got Array")
168159
}
169160
ColumnarValue::Scalar(scalar) => {
170161
if scalar.is_null() {
171-
current_value = &args[2];
162+
Ok(if_null.clone())
163+
} else {
164+
Ok(if_non_null.clone())
172165
}
173-
Ok(current_value.clone())
174166
}
175167
}
176168
}

datafusion/functions/src/core/version.rs

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@
1717

1818
//! [`VersionFunc`]: Implementation of the `version` function.
1919
20+
use crate::utils::take_function_args;
2021
use arrow::datatypes::DataType;
21-
use datafusion_common::{internal_err, plan_err, Result, ScalarValue};
22+
use datafusion_common::{Result, ScalarValue};
2223
use datafusion_expr::{
2324
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
2425
};
2526
use datafusion_macros::user_doc;
2627
use std::any::Any;
28+
2729
#[user_doc(
2830
doc_section(label = "Other Functions"),
2931
description = "Returns the version of DataFusion.",
@@ -70,21 +72,16 @@ impl ScalarUDFImpl for VersionFunc {
7072
}
7173

7274
fn return_type(&self, args: &[DataType]) -> Result<DataType> {
73-
if args.is_empty() {
74-
Ok(DataType::Utf8)
75-
} else {
76-
plan_err!("version expects no arguments")
77-
}
75+
let [] = take_function_args(self.name(), args)?;
76+
Ok(DataType::Utf8)
7877
}
7978

8079
fn invoke_batch(
8180
&self,
8281
args: &[ColumnarValue],
8382
_number_rows: usize,
8483
) -> Result<ColumnarValue> {
85-
if !args.is_empty() {
86-
return internal_err!("{} function does not accept arguments", self.name());
87-
}
84+
let [] = take_function_args(self.name(), args)?;
8885
// TODO it would be great to add rust version and arrow version,
8986
// but that requires a `build.rs` script and/or adding a version const to arrow-rs
9087
let version = format!(

datafusion/functions/src/crypto/basic.rs

Lines changed: 8 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use blake2::{Blake2b512, Blake2s256, Digest};
2424
use blake3::Hasher as Blake3;
2525
use datafusion_common::cast::as_binary_array;
2626

27+
use crate::utils::take_function_args;
2728
use arrow::compute::StringArrayType;
2829
use datafusion_common::plan_err;
2930
use datafusion_common::{
@@ -41,14 +42,8 @@ macro_rules! define_digest_function {
4142
($NAME: ident, $METHOD: ident, $DOC: expr) => {
4243
#[doc = $DOC]
4344
pub fn $NAME(args: &[ColumnarValue]) -> Result<ColumnarValue> {
44-
if args.len() != 1 {
45-
return exec_err!(
46-
"{:?} args were supplied but {} takes exactly one argument",
47-
args.len(),
48-
DigestAlgorithm::$METHOD.to_string()
49-
);
50-
}
51-
digest_process(&args[0], DigestAlgorithm::$METHOD)
45+
let [data] = take_function_args(&DigestAlgorithm::$METHOD.to_string(), args)?;
46+
digest_process(data, DigestAlgorithm::$METHOD)
5247
}
5348
};
5449
}
@@ -114,13 +109,8 @@ pub enum DigestAlgorithm {
114109
/// Second argument is the algorithm to use.
115110
/// Standard algorithms are md5, sha1, sha224, sha256, sha384 and sha512.
116111
pub fn digest(args: &[ColumnarValue]) -> Result<ColumnarValue> {
117-
if args.len() != 2 {
118-
return exec_err!(
119-
"{:?} args were supplied but digest takes exactly two arguments",
120-
args.len()
121-
);
122-
}
123-
let digest_algorithm = match &args[1] {
112+
let [data, digest_algorithm] = take_function_args("digest", args)?;
113+
let digest_algorithm = match digest_algorithm {
124114
ColumnarValue::Scalar(scalar) => match scalar.try_as_str() {
125115
Some(Some(method)) => method.parse::<DigestAlgorithm>(),
126116
_ => exec_err!("Unsupported data type {scalar:?} for function digest"),
@@ -129,7 +119,7 @@ pub fn digest(args: &[ColumnarValue]) -> Result<ColumnarValue> {
129119
internal_err!("Digest using dynamically decided method is not yet supported")
130120
}
131121
}?;
132-
digest_process(&args[0], digest_algorithm)
122+
digest_process(data, digest_algorithm)
133123
}
134124

135125
impl FromStr for DigestAlgorithm {
@@ -175,15 +165,8 @@ impl fmt::Display for DigestAlgorithm {
175165

176166
/// computes md5 hash digest of the given input
177167
pub fn md5(args: &[ColumnarValue]) -> Result<ColumnarValue> {
178-
if args.len() != 1 {
179-
return exec_err!(
180-
"{:?} args were supplied but {} takes exactly one argument",
181-
args.len(),
182-
DigestAlgorithm::Md5
183-
);
184-
}
185-
186-
let value = digest_process(&args[0], DigestAlgorithm::Md5)?;
168+
let [data] = take_function_args("md5", args)?;
169+
let value = digest_process(data, DigestAlgorithm::Md5)?;
187170

188171
// md5 requires special handling because of its unique utf8 return type
189172
Ok(match value {

datafusion/functions/src/datetime/date_part.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ use arrow::datatypes::DataType::{
2828
use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second};
2929
use arrow::datatypes::{DataType, TimeUnit};
3030

31+
use crate::utils::take_function_args;
3132
use datafusion_common::not_impl_err;
3233
use datafusion_common::{
3334
cast::{
@@ -140,10 +141,9 @@ impl ScalarUDFImpl for DatePartFunc {
140141
}
141142

142143
fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result<ReturnInfo> {
143-
// Length check handled in the signature
144-
debug_assert_eq!(args.scalar_arguments.len(), 2);
144+
let [field, _] = take_function_args(self.name(), args.scalar_arguments)?;
145145

146-
args.scalar_arguments[0]
146+
field
147147
.and_then(|sv| {
148148
sv.try_as_str()
149149
.flatten()

0 commit comments

Comments
 (0)