Skip to content

Commit e4b78c7

Browse files
authored
minor: simplify union_extract code (#14640)
* minor: simplify `union_extract` code * Fix CI tests on main
1 parent 71f9d0c commit e4b78c7

File tree

1 file changed

+10
-19
lines changed

1 file changed

+10
-19
lines changed

datafusion/functions/src/core/union_extract.rs

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
use arrow::array::Array;
1919
use arrow::datatypes::{DataType, FieldRef, UnionFields};
2020
use datafusion_common::cast::as_union_array;
21+
use datafusion_common::utils::take_function_args;
2122
use datafusion_common::{
2223
exec_datafusion_err, exec_err, internal_err, Result, ScalarValue,
2324
};
@@ -113,22 +114,15 @@ impl ScalarUDFImpl for UnionExtractFun {
113114
}
114115

115116
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
116-
let args = args.args;
117+
let [array, target_name] = take_function_args("union_extract", args.args)?;
117118

118-
if args.len() != 2 {
119-
return exec_err!(
120-
"union_extract expects 2 arguments, got {} instead",
121-
args.len()
122-
);
123-
}
124-
125-
let target_name = match &args[1] {
119+
let target_name = match target_name {
126120
ColumnarValue::Scalar(ScalarValue::Utf8(Some(target_name))) => Ok(target_name),
127121
ColumnarValue::Scalar(ScalarValue::Utf8(None)) => exec_err!("union_extract second argument must be a non-null string literal, got a null instead"),
128-
_ => exec_err!("union_extract second argument must be a non-null string literal, got {} instead", &args[1].data_type()),
129-
};
122+
_ => exec_err!("union_extract second argument must be a non-null string literal, got {} instead", target_name.data_type()),
123+
}?;
130124

131-
match &args[0] {
125+
match array {
132126
ColumnarValue::Array(array) => {
133127
let union_array = as_union_array(&array).map_err(|_| {
134128
exec_datafusion_err!(
@@ -140,19 +134,16 @@ impl ScalarUDFImpl for UnionExtractFun {
140134
Ok(ColumnarValue::Array(
141135
arrow::compute::kernels::union_extract::union_extract(
142136
union_array,
143-
target_name?,
137+
&target_name,
144138
)?,
145139
))
146140
}
147141
ColumnarValue::Scalar(ScalarValue::Union(value, fields, _)) => {
148-
let target_name = target_name?;
149-
let (target_type_id, target) = find_field(fields, target_name)?;
142+
let (target_type_id, target) = find_field(&fields, &target_name)?;
150143

151144
let result = match value {
152-
Some((type_id, value)) if target_type_id == *type_id => {
153-
*value.clone()
154-
}
155-
_ => ScalarValue::try_from(target.data_type())?,
145+
Some((type_id, value)) if target_type_id == type_id => *value,
146+
_ => ScalarValue::try_new_null(target.data_type())?,
156147
};
157148

158149
Ok(ColumnarValue::Scalar(result))

0 commit comments

Comments
 (0)