@@ -19,11 +19,13 @@ use std::any::Any;
19
19
use std:: cmp:: max;
20
20
use std:: sync:: Arc ;
21
21
22
- use arrow:: array:: { ArrayRef , GenericStringArray , OffsetSizeTrait } ;
22
+ use arrow:: array:: {
23
+ ArrayAccessor , ArrayIter , ArrayRef , AsArray , GenericStringArray , OffsetSizeTrait ,
24
+ } ;
23
25
use arrow:: datatypes:: DataType ;
24
26
25
- use datafusion_common:: cast:: { as_generic_string_array , as_int64_array} ;
26
- use datafusion_common:: { exec_err , Result } ;
27
+ use datafusion_common:: cast:: as_int64_array;
28
+ use datafusion_common:: { DataFusionError , Result } ;
27
29
use datafusion_expr:: TypeSignature :: Exact ;
28
30
use datafusion_expr:: { ColumnarValue , ScalarUDFImpl , Signature , Volatility } ;
29
31
@@ -51,6 +53,8 @@ impl SubstrFunc {
51
53
Exact ( vec![ LargeUtf8 , Int64 ] ) ,
52
54
Exact ( vec![ Utf8 , Int64 , Int64 ] ) ,
53
55
Exact ( vec![ LargeUtf8 , Int64 , Int64 ] ) ,
56
+ Exact ( vec![ Utf8View , Int64 ] ) ,
57
+ Exact ( vec![ Utf8View , Int64 , Int64 ] ) ,
54
58
] ,
55
59
Volatility :: Immutable ,
56
60
) ,
@@ -77,30 +81,49 @@ impl ScalarUDFImpl for SubstrFunc {
77
81
}
78
82
79
83
fn invoke ( & self , args : & [ ColumnarValue ] ) -> Result < ColumnarValue > {
80
- match args[ 0 ] . data_type ( ) {
81
- DataType :: Utf8 => make_scalar_function ( substr :: < i32 > , vec ! [ ] ) ( args) ,
82
- DataType :: LargeUtf8 => make_scalar_function ( substr :: < i64 > , vec ! [ ] ) ( args) ,
83
- other => exec_err ! ( "Unsupported data type {other:?} for function substr" ) ,
84
- }
84
+ make_scalar_function ( substr, vec ! [ ] ) ( args)
85
85
}
86
86
87
87
fn aliases ( & self ) -> & [ String ] {
88
88
& self . aliases
89
89
}
90
90
}
91
91
92
+ pub fn substr ( args : & [ ArrayRef ] ) -> Result < ArrayRef > {
93
+ match args[ 0 ] . data_type ( ) {
94
+ DataType :: Utf8 => {
95
+ let string_array = args[ 0 ] . as_string :: < i32 > ( ) ;
96
+ calculate_substr :: < _ , i32 > ( string_array, & args[ 1 ..] )
97
+ }
98
+ DataType :: LargeUtf8 => {
99
+ let string_array = args[ 0 ] . as_string :: < i64 > ( ) ;
100
+ calculate_substr :: < _ , i64 > ( string_array, & args[ 1 ..] )
101
+ }
102
+ DataType :: Utf8View => {
103
+ let string_array = args[ 0 ] . as_string_view ( ) ;
104
+ calculate_substr :: < _ , i32 > ( string_array, & args[ 1 ..] )
105
+ }
106
+ _ => Err ( DataFusionError :: Internal (
107
+ "Unsupported data type for function substr" . to_string ( ) ,
108
+ ) ) ,
109
+ }
110
+ }
111
+
92
112
/// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).)
93
113
/// substr('alphabet', 3) = 'phabet'
94
114
/// substr('alphabet', 3, 2) = 'ph'
95
115
/// The implementation uses UTF-8 code points as characters
96
- pub fn substr < T : OffsetSizeTrait > ( args : & [ ArrayRef ] ) -> Result < ArrayRef > {
116
+ fn calculate_substr < ' a , V , T > ( string_array : V , args : & [ ArrayRef ] ) -> Result < ArrayRef >
117
+ where
118
+ V : ArrayAccessor < Item = & ' a str > ,
119
+ T : OffsetSizeTrait ,
120
+ {
97
121
match args. len ( ) {
98
- 2 => {
99
- let string_array = as_generic_string_array :: < T > ( & args [ 0 ] ) ? ;
100
- let start_array = as_int64_array ( & args[ 1 ] ) ?;
122
+ 1 => {
123
+ let iter = ArrayIter :: new ( string_array ) ;
124
+ let start_array = as_int64_array ( & args[ 0 ] ) ?;
101
125
102
- let result = string_array
103
- . iter ( )
126
+ let result = iter
104
127
. zip ( start_array. iter ( ) )
105
128
. map ( |( string, start) | match ( string, start) {
106
129
( Some ( string) , Some ( start) ) => {
@@ -113,24 +136,23 @@ pub fn substr<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
113
136
_ => None ,
114
137
} )
115
138
. collect :: < GenericStringArray < T > > ( ) ;
116
-
117
139
Ok ( Arc :: new ( result) as ArrayRef )
118
140
}
119
- 3 => {
120
- let string_array = as_generic_string_array :: < T > ( & args [ 0 ] ) ? ;
121
- let start_array = as_int64_array ( & args[ 1 ] ) ?;
122
- let count_array = as_int64_array ( & args[ 2 ] ) ?;
141
+ 2 => {
142
+ let iter = ArrayIter :: new ( string_array ) ;
143
+ let start_array = as_int64_array ( & args[ 0 ] ) ?;
144
+ let count_array = as_int64_array ( & args[ 1 ] ) ?;
123
145
124
- let result = string_array
125
- . iter ( )
146
+ let result = iter
126
147
. zip ( start_array. iter ( ) )
127
148
. zip ( count_array. iter ( ) )
128
149
. map ( |( ( string, start) , count) | match ( string, start, count) {
129
150
( Some ( string) , Some ( start) , Some ( count) ) => {
130
151
if count < 0 {
131
- exec_err ! (
132
- "negative substring length not allowed: substr(<str>, {start}, {count})"
133
- )
152
+ Err ( DataFusionError :: Execution ( format ! (
153
+ "negative substring length not allowed: substr(<str>, {}, {})" ,
154
+ start, count
155
+ ) ) )
134
156
} else {
135
157
let skip = max ( 0 , start - 1 ) ;
136
158
let count = max ( 0 , count + ( if start < 1 { start - 1 } else { 0 } ) ) ;
@@ -143,9 +165,10 @@ pub fn substr<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
143
165
144
166
Ok ( Arc :: new ( result) as ArrayRef )
145
167
}
146
- other => {
147
- exec_err ! ( "substr was called with {other} arguments. It requires 2 or 3." )
148
- }
168
+ _ => Err ( DataFusionError :: Execution ( format ! (
169
+ "substr was called with {} arguments. It requires 2 or 3." ,
170
+ args. len( )
171
+ ) ) ) ,
149
172
}
150
173
}
151
174
@@ -162,6 +185,71 @@ mod tests {
162
185
163
186
#[ test]
164
187
fn test_functions ( ) -> Result < ( ) > {
188
+ test_function ! (
189
+ SubstrFunc :: new( ) ,
190
+ & [
191
+ ColumnarValue :: Scalar ( ScalarValue :: Utf8View ( None ) ) ,
192
+ ColumnarValue :: Scalar ( ScalarValue :: from( 1i64 ) ) ,
193
+ ] ,
194
+ Ok ( None ) ,
195
+ & str ,
196
+ Utf8 ,
197
+ StringArray
198
+ ) ;
199
+ test_function ! (
200
+ SubstrFunc :: new( ) ,
201
+ & [
202
+ ColumnarValue :: Scalar ( ScalarValue :: Utf8View ( Some ( String :: from(
203
+ "alphabet"
204
+ ) ) ) ) ,
205
+ ColumnarValue :: Scalar ( ScalarValue :: from( 0i64 ) ) ,
206
+ ] ,
207
+ Ok ( Some ( "alphabet" ) ) ,
208
+ & str ,
209
+ Utf8 ,
210
+ StringArray
211
+ ) ;
212
+ test_function ! (
213
+ SubstrFunc :: new( ) ,
214
+ & [
215
+ ColumnarValue :: Scalar ( ScalarValue :: Utf8View ( Some ( String :: from(
216
+ "joséésoj"
217
+ ) ) ) ) ,
218
+ ColumnarValue :: Scalar ( ScalarValue :: from( 5i64 ) ) ,
219
+ ] ,
220
+ Ok ( Some ( "ésoj" ) ) ,
221
+ & str ,
222
+ Utf8 ,
223
+ StringArray
224
+ ) ;
225
+ test_function ! (
226
+ SubstrFunc :: new( ) ,
227
+ & [
228
+ ColumnarValue :: Scalar ( ScalarValue :: Utf8View ( Some ( String :: from(
229
+ "alphabet"
230
+ ) ) ) ) ,
231
+ ColumnarValue :: Scalar ( ScalarValue :: from( 3i64 ) ) ,
232
+ ColumnarValue :: Scalar ( ScalarValue :: from( 2i64 ) ) ,
233
+ ] ,
234
+ Ok ( Some ( "ph" ) ) ,
235
+ & str ,
236
+ Utf8 ,
237
+ StringArray
238
+ ) ;
239
+ test_function ! (
240
+ SubstrFunc :: new( ) ,
241
+ & [
242
+ ColumnarValue :: Scalar ( ScalarValue :: Utf8View ( Some ( String :: from(
243
+ "alphabet"
244
+ ) ) ) ) ,
245
+ ColumnarValue :: Scalar ( ScalarValue :: from( 3i64 ) ) ,
246
+ ColumnarValue :: Scalar ( ScalarValue :: from( 20i64 ) ) ,
247
+ ] ,
248
+ Ok ( Some ( "phabet" ) ) ,
249
+ & str ,
250
+ Utf8 ,
251
+ StringArray
252
+ ) ;
165
253
test_function ! (
166
254
SubstrFunc :: new( ) ,
167
255
& [
0 commit comments