18
18
use arrow:: array:: Array ;
19
19
use arrow:: datatypes:: { DataType , FieldRef , UnionFields } ;
20
20
use datafusion_common:: cast:: as_union_array;
21
+ use datafusion_common:: utils:: take_function_args;
21
22
use datafusion_common:: {
22
23
exec_datafusion_err, exec_err, internal_err, Result , ScalarValue ,
23
24
} ;
@@ -113,22 +114,15 @@ impl ScalarUDFImpl for UnionExtractFun {
113
114
}
114
115
115
116
fn invoke_with_args ( & self , args : ScalarFunctionArgs ) -> Result < ColumnarValue > {
116
- let args = args. args ;
117
+ let [ array , target_name ] = take_function_args ( "union_extract" , args. args ) ? ;
117
118
118
- if args. len ( ) != 2 {
119
- return exec_err ! (
120
- "union_extract expects 2 arguments, got {} instead" ,
121
- args. len( )
122
- ) ;
123
- }
124
-
125
- let target_name = match & args[ 1 ] {
119
+ let target_name = match target_name {
126
120
ColumnarValue :: Scalar ( ScalarValue :: Utf8 ( Some ( target_name) ) ) => Ok ( target_name) ,
127
121
ColumnarValue :: Scalar ( ScalarValue :: Utf8 ( None ) ) => exec_err ! ( "union_extract second argument must be a non-null string literal, got a null instead" ) ,
128
- _ => exec_err ! ( "union_extract second argument must be a non-null string literal, got {} instead" , & args [ 1 ] . data_type( ) ) ,
129
- } ;
122
+ _ => exec_err ! ( "union_extract second argument must be a non-null string literal, got {} instead" , target_name . data_type( ) ) ,
123
+ } ? ;
130
124
131
- match & args [ 0 ] {
125
+ match array {
132
126
ColumnarValue :: Array ( array) => {
133
127
let union_array = as_union_array ( & array) . map_err ( |_| {
134
128
exec_datafusion_err ! (
@@ -140,19 +134,16 @@ impl ScalarUDFImpl for UnionExtractFun {
140
134
Ok ( ColumnarValue :: Array (
141
135
arrow:: compute:: kernels:: union_extract:: union_extract (
142
136
union_array,
143
- target_name? ,
137
+ & target_name,
144
138
) ?,
145
139
) )
146
140
}
147
141
ColumnarValue :: Scalar ( ScalarValue :: Union ( value, fields, _) ) => {
148
- let target_name = target_name?;
149
- let ( target_type_id, target) = find_field ( fields, target_name) ?;
142
+ let ( target_type_id, target) = find_field ( & fields, & target_name) ?;
150
143
151
144
let result = match value {
152
- Some ( ( type_id, value) ) if target_type_id == * type_id => {
153
- * value. clone ( )
154
- }
155
- _ => ScalarValue :: try_from ( target. data_type ( ) ) ?,
145
+ Some ( ( type_id, value) ) if target_type_id == type_id => * value,
146
+ _ => ScalarValue :: try_new_null ( target. data_type ( ) ) ?,
156
147
} ;
157
148
158
149
Ok ( ColumnarValue :: Scalar ( result) )
0 commit comments