@@ -19,18 +19,18 @@ use std::any::Any;
19
19
use std:: cmp:: max;
20
20
use std:: sync:: Arc ;
21
21
22
+ use crate :: utils:: { make_scalar_function, utf8_to_str_type} ;
22
23
use arrow:: array:: {
23
- ArrayAccessor , ArrayIter , ArrayRef , AsArray , GenericStringArray , OffsetSizeTrait ,
24
+ make_view, Array , ArrayAccessor , ArrayIter , ArrayRef , AsArray , ByteView ,
25
+ GenericStringArray , OffsetSizeTrait , StringViewArray ,
24
26
} ;
25
27
use arrow:: datatypes:: DataType ;
26
-
28
+ use arrow_buffer :: { NullBufferBuilder , ScalarBuffer } ;
27
29
use datafusion_common:: cast:: as_int64_array;
28
30
use datafusion_common:: { exec_datafusion_err, exec_err, Result } ;
29
31
use datafusion_expr:: TypeSignature :: Exact ;
30
32
use datafusion_expr:: { ColumnarValue , ScalarUDFImpl , Signature , Volatility } ;
31
33
32
- use crate :: utils:: { make_scalar_function, utf8_to_str_type} ;
33
-
34
34
#[ derive( Debug ) ]
35
35
pub struct SubstrFunc {
36
36
signature : Signature ,
@@ -77,7 +77,11 @@ impl ScalarUDFImpl for SubstrFunc {
77
77
}
78
78
79
79
fn return_type ( & self , arg_types : & [ DataType ] ) -> Result < DataType > {
80
- utf8_to_str_type ( & arg_types[ 0 ] , "substr" )
80
+ if arg_types[ 0 ] == DataType :: Utf8View {
81
+ Ok ( DataType :: Utf8View )
82
+ } else {
83
+ utf8_to_str_type ( & arg_types[ 0 ] , "substr" )
84
+ }
81
85
}
82
86
83
87
fn invoke ( & self , args : & [ ColumnarValue ] ) -> Result < ColumnarValue > {
@@ -89,29 +93,188 @@ impl ScalarUDFImpl for SubstrFunc {
89
93
}
90
94
}
91
95
96
+ /// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).)
97
+ /// substr('alphabet', 3) = 'phabet'
98
+ /// substr('alphabet', 3, 2) = 'ph'
99
+ /// The implementation uses UTF-8 code points as characters
92
100
pub fn substr ( args : & [ ArrayRef ] ) -> Result < ArrayRef > {
93
101
match args[ 0 ] . data_type ( ) {
94
102
DataType :: Utf8 => {
95
103
let string_array = args[ 0 ] . as_string :: < i32 > ( ) ;
96
- calculate_substr :: < _ , i32 > ( string_array, & args[ 1 ..] )
104
+ string_substr :: < _ , i32 > ( string_array, & args[ 1 ..] )
97
105
}
98
106
DataType :: LargeUtf8 => {
99
107
let string_array = args[ 0 ] . as_string :: < i64 > ( ) ;
100
- calculate_substr :: < _ , i64 > ( string_array, & args[ 1 ..] )
108
+ string_substr :: < _ , i64 > ( string_array, & args[ 1 ..] )
101
109
}
102
110
DataType :: Utf8View => {
103
111
let string_array = args[ 0 ] . as_string_view ( ) ;
104
- calculate_substr :: < _ , i32 > ( string_array, & args[ 1 ..] )
112
+ string_view_substr ( string_array, & args[ 1 ..] )
105
113
}
106
- other => exec_err ! ( "Unsupported data type {other:?} for function substr" ) ,
114
+ other => exec_err ! (
115
+ "Unsupported data type {other:?} for function substr,\
116
+ expected Utf8View, Utf8 or LargeUtf8."
117
+ ) ,
107
118
}
108
119
}
109
120
110
- /// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).)
111
- /// substr('alphabet', 3) = 'phabet'
112
- /// substr('alphabet', 3, 2) = 'ph'
113
- /// The implementation uses UTF-8 code points as characters
114
- fn calculate_substr < ' a , V , T > ( string_array : V , args : & [ ArrayRef ] ) -> Result < ArrayRef >
121
+ // Return the exact byte index for [start, end), set count to -1 to ignore count
122
+ fn get_true_start_end ( input : & str , start : usize , count : i64 ) -> ( usize , usize ) {
123
+ let ( mut st, mut ed) = ( input. len ( ) , input. len ( ) ) ;
124
+ let mut start_counting = false ;
125
+ let mut cnt = 0 ;
126
+ for ( char_cnt, ( byte_cnt, _) ) in input. char_indices ( ) . enumerate ( ) {
127
+ if char_cnt == start {
128
+ st = byte_cnt;
129
+ if count != -1 {
130
+ start_counting = true ;
131
+ } else {
132
+ break ;
133
+ }
134
+ }
135
+ if start_counting {
136
+ if cnt == count {
137
+ ed = byte_cnt;
138
+ break ;
139
+ }
140
+ cnt += 1 ;
141
+ }
142
+ }
143
+ ( st, ed)
144
+ }
145
+
146
+ /// Make a `u128` based on the given substr, start(offset to view.offset), and
147
+ /// push into to the given buffers
148
+ fn make_and_append_view (
149
+ views_buffer : & mut Vec < u128 > ,
150
+ null_builder : & mut NullBufferBuilder ,
151
+ raw : & u128 ,
152
+ substr : & str ,
153
+ start : u32 ,
154
+ ) {
155
+ let substr_len = substr. len ( ) ;
156
+ if substr_len == 0 {
157
+ null_builder. append_null ( ) ;
158
+ views_buffer. push ( 0 ) ;
159
+ } else {
160
+ let sub_view = if substr_len > 12 {
161
+ let view = ByteView :: from ( * raw) ;
162
+ make_view ( substr. as_bytes ( ) , view. buffer_index , view. offset + start)
163
+ } else {
164
+ // inline value does not need block id or offset
165
+ make_view ( substr. as_bytes ( ) , 0 , 0 )
166
+ } ;
167
+ views_buffer. push ( sub_view) ;
168
+ null_builder. append_non_null ( ) ;
169
+ }
170
+ }
171
+
172
+ // The decoding process refs the trait at: arrow/arrow-data/src/byte_view.rs:44
173
+ // From<u128> for ByteView
174
+ fn string_view_substr (
175
+ string_view_array : & StringViewArray ,
176
+ args : & [ ArrayRef ] ,
177
+ ) -> Result < ArrayRef > {
178
+ let mut views_buf = Vec :: with_capacity ( string_view_array. len ( ) ) ;
179
+ let mut null_builder = NullBufferBuilder :: new ( string_view_array. len ( ) ) ;
180
+
181
+ let start_array = as_int64_array ( & args[ 0 ] ) ?;
182
+
183
+ match args. len ( ) {
184
+ 1 => {
185
+ for ( idx, ( raw, start) ) in string_view_array
186
+ . views ( )
187
+ . iter ( )
188
+ . zip ( start_array. iter ( ) )
189
+ . enumerate ( )
190
+ {
191
+ if let Some ( start) = start {
192
+ let start = ( start - 1 ) . max ( 0 ) as usize ;
193
+
194
+ // Safety:
195
+ // idx is always smaller or equal to string_view_array.views.len()
196
+ unsafe {
197
+ let str = string_view_array. value_unchecked ( idx) ;
198
+ let ( start, end) = get_true_start_end ( str, start, -1 ) ;
199
+ let substr = & str[ start..end] ;
200
+
201
+ make_and_append_view (
202
+ & mut views_buf,
203
+ & mut null_builder,
204
+ raw,
205
+ substr,
206
+ start as u32 ,
207
+ ) ;
208
+ }
209
+ } else {
210
+ null_builder. append_null ( ) ;
211
+ views_buf. push ( 0 ) ;
212
+ }
213
+ }
214
+ }
215
+ 2 => {
216
+ let count_array = as_int64_array ( & args[ 1 ] ) ?;
217
+ for ( idx, ( ( raw, start) , count) ) in string_view_array
218
+ . views ( )
219
+ . iter ( )
220
+ . zip ( start_array. iter ( ) )
221
+ . zip ( count_array. iter ( ) )
222
+ . enumerate ( )
223
+ {
224
+ if let ( Some ( start) , Some ( count) ) = ( start, count) {
225
+ let start = ( start - 1 ) . max ( 0 ) as usize ;
226
+ if count < 0 {
227
+ return exec_err ! (
228
+ "negative substring length not allowed: substr(<str>, {start}, {count})"
229
+ ) ;
230
+ } else {
231
+ // Safety:
232
+ // idx is always smaller or equal to string_view_array.views.len()
233
+ unsafe {
234
+ let str = string_view_array. value_unchecked ( idx) ;
235
+ let ( start, end) = get_true_start_end ( str, start, count) ;
236
+ let substr = & str[ start..end] ;
237
+
238
+ make_and_append_view (
239
+ & mut views_buf,
240
+ & mut null_builder,
241
+ raw,
242
+ substr,
243
+ start as u32 ,
244
+ ) ;
245
+ }
246
+ }
247
+ } else {
248
+ null_builder. append_null ( ) ;
249
+ views_buf. push ( 0 ) ;
250
+ }
251
+ }
252
+ }
253
+ other => {
254
+ return exec_err ! (
255
+ "substr was called with {other} arguments. It requires 2 or 3."
256
+ )
257
+ }
258
+ }
259
+
260
+ let views_buf = ScalarBuffer :: from ( views_buf) ;
261
+ let nulls_buf = null_builder. finish ( ) ;
262
+
263
+ // Safety:
264
+ // (1) The blocks of the given views are all provided
265
+ // (2) Each of the range `view.offset+start..end` of view in views_buf is within
266
+ // the bounds of each of the blocks
267
+ unsafe {
268
+ let array = StringViewArray :: new_unchecked (
269
+ views_buf,
270
+ string_view_array. data_buffers ( ) . to_vec ( ) ,
271
+ nulls_buf,
272
+ ) ;
273
+ Ok ( Arc :: new ( array) as ArrayRef )
274
+ }
275
+ }
276
+
277
+ fn string_substr < ' a , V , T > ( string_array : V , args : & [ ArrayRef ] ) -> Result < ArrayRef >
115
278
where
116
279
V : ArrayAccessor < Item = & ' a str > ,
117
280
T : OffsetSizeTrait ,
@@ -174,8 +337,8 @@ where
174
337
175
338
#[ cfg( test) ]
176
339
mod tests {
177
- use arrow:: array:: { Array , StringArray } ;
178
- use arrow:: datatypes:: DataType :: Utf8 ;
340
+ use arrow:: array:: { Array , StringArray , StringViewArray } ;
341
+ use arrow:: datatypes:: DataType :: { Utf8 , Utf8View } ;
179
342
180
343
use datafusion_common:: { exec_err, Result , ScalarValue } ;
181
344
use datafusion_expr:: { ColumnarValue , ScalarUDFImpl } ;
@@ -193,8 +356,8 @@ mod tests {
193
356
] ,
194
357
Ok ( None ) ,
195
358
& str ,
196
- Utf8 ,
197
- StringArray
359
+ Utf8View ,
360
+ StringViewArray
198
361
) ;
199
362
test_function ! (
200
363
SubstrFunc :: new( ) ,
@@ -206,8 +369,35 @@ mod tests {
206
369
] ,
207
370
Ok ( Some ( "alphabet" ) ) ,
208
371
& str ,
209
- Utf8 ,
210
- StringArray
372
+ Utf8View ,
373
+ StringViewArray
374
+ ) ;
375
+ test_function ! (
376
+ SubstrFunc :: new( ) ,
377
+ & [
378
+ ColumnarValue :: Scalar ( ScalarValue :: Utf8View ( Some ( String :: from(
379
+ "this és longer than 12B"
380
+ ) ) ) ) ,
381
+ ColumnarValue :: Scalar ( ScalarValue :: from( 5i64 ) ) ,
382
+ ColumnarValue :: Scalar ( ScalarValue :: from( 2i64 ) ) ,
383
+ ] ,
384
+ Ok ( Some ( " é" ) ) ,
385
+ & str ,
386
+ Utf8View ,
387
+ StringViewArray
388
+ ) ;
389
+ test_function ! (
390
+ SubstrFunc :: new( ) ,
391
+ & [
392
+ ColumnarValue :: Scalar ( ScalarValue :: Utf8View ( Some ( String :: from(
393
+ "this is longer than 12B"
394
+ ) ) ) ) ,
395
+ ColumnarValue :: Scalar ( ScalarValue :: from( 5i64 ) ) ,
396
+ ] ,
397
+ Ok ( Some ( " is longer than 12B" ) ) ,
398
+ & str ,
399
+ Utf8View ,
400
+ StringViewArray
211
401
) ;
212
402
test_function ! (
213
403
SubstrFunc :: new( ) ,
@@ -219,8 +409,8 @@ mod tests {
219
409
] ,
220
410
Ok ( Some ( "ésoj" ) ) ,
221
411
& str ,
222
- Utf8 ,
223
- StringArray
412
+ Utf8View ,
413
+ StringViewArray
224
414
) ;
225
415
test_function ! (
226
416
SubstrFunc :: new( ) ,
@@ -233,8 +423,8 @@ mod tests {
233
423
] ,
234
424
Ok ( Some ( "ph" ) ) ,
235
425
& str ,
236
- Utf8 ,
237
- StringArray
426
+ Utf8View ,
427
+ StringViewArray
238
428
) ;
239
429
test_function ! (
240
430
SubstrFunc :: new( ) ,
@@ -247,8 +437,8 @@ mod tests {
247
437
] ,
248
438
Ok ( Some ( "phabet" ) ) ,
249
439
& str ,
250
- Utf8 ,
251
- StringArray
440
+ Utf8View ,
441
+ StringViewArray
252
442
) ;
253
443
test_function ! (
254
444
SubstrFunc :: new( ) ,
0 commit comments