@@ -21,7 +21,9 @@ use std::sync::Arc;
21
21
use arrow:: array:: { ArrayRef , GenericStringArray , OffsetSizeTrait } ;
22
22
use arrow:: datatypes:: DataType ;
23
23
24
- use datafusion_common:: cast:: { as_generic_string_array, as_int64_array} ;
24
+ use datafusion_common:: cast:: {
25
+ as_generic_string_array, as_int64_array, as_string_view_array,
26
+ } ;
25
27
use datafusion_common:: { exec_err, Result } ;
26
28
use datafusion_expr:: TypeSignature :: * ;
27
29
use datafusion_expr:: { ColumnarValue , Volatility } ;
@@ -46,8 +48,10 @@ impl OverlayFunc {
46
48
Self {
47
49
signature : Signature :: one_of (
48
50
vec ! [
51
+ Exact ( vec![ Utf8View , Utf8View , Int64 , Int64 ] ) ,
49
52
Exact ( vec![ Utf8 , Utf8 , Int64 , Int64 ] ) ,
50
53
Exact ( vec![ LargeUtf8 , LargeUtf8 , Int64 , Int64 ] ) ,
54
+ Exact ( vec![ Utf8View , Utf8View , Int64 ] ) ,
51
55
Exact ( vec![ Utf8 , Utf8 , Int64 ] ) ,
52
56
Exact ( vec![ LargeUtf8 , LargeUtf8 , Int64 ] ) ,
53
57
] ,
@@ -76,54 +80,107 @@ impl ScalarUDFImpl for OverlayFunc {
76
80
77
81
fn invoke ( & self , args : & [ ColumnarValue ] ) -> Result < ColumnarValue > {
78
82
match args[ 0 ] . data_type ( ) {
79
- DataType :: Utf8 => make_scalar_function ( overlay :: < i32 > , vec ! [ ] ) ( args) ,
83
+ DataType :: Utf8View | DataType :: Utf8 => {
84
+ make_scalar_function ( overlay :: < i32 > , vec ! [ ] ) ( args)
85
+ }
80
86
DataType :: LargeUtf8 => make_scalar_function ( overlay :: < i64 > , vec ! [ ] ) ( args) ,
81
87
other => exec_err ! ( "Unsupported data type {other:?} for function overlay" ) ,
82
88
}
83
89
}
84
90
}
85
91
92
+ macro_rules! process_overlay {
93
+ // For the three-argument case
94
+ ( $string_array: expr, $characters_array: expr, $pos_num: expr) => { {
95
+ $string_array
96
+ . iter( )
97
+ . zip( $characters_array. iter( ) )
98
+ . zip( $pos_num. iter( ) )
99
+ . map( |( ( string, characters) , start_pos) | {
100
+ match ( string, characters, start_pos) {
101
+ ( Some ( string) , Some ( characters) , Some ( start_pos) ) => {
102
+ let string_len = string. chars( ) . count( ) ;
103
+ let characters_len = characters. chars( ) . count( ) ;
104
+ let replace_len = characters_len as i64 ;
105
+ let mut res =
106
+ String :: with_capacity( string_len. max( characters_len) ) ;
107
+
108
+ //as sql replace index start from 1 while string index start from 0
109
+ if start_pos > 1 && start_pos - 1 < string_len as i64 {
110
+ let start = ( start_pos - 1 ) as usize ;
111
+ res. push_str( & string[ ..start] ) ;
112
+ }
113
+ res. push_str( characters) ;
114
+ // if start + replace_len - 1 >= string_length, just to string end
115
+ if start_pos + replace_len - 1 < string_len as i64 {
116
+ let end = ( start_pos + replace_len - 1 ) as usize ;
117
+ res. push_str( & string[ end..] ) ;
118
+ }
119
+ Ok ( Some ( res) )
120
+ }
121
+ _ => Ok ( None ) ,
122
+ }
123
+ } )
124
+ . collect:: <Result <GenericStringArray <T >>>( )
125
+ } } ;
126
+
127
+ // For the four-argument case
128
+ ( $string_array: expr, $characters_array: expr, $pos_num: expr, $len_num: expr) => { {
129
+ $string_array
130
+ . iter( )
131
+ . zip( $characters_array. iter( ) )
132
+ . zip( $pos_num. iter( ) )
133
+ . zip( $len_num. iter( ) )
134
+ . map( |( ( ( string, characters) , start_pos) , len) | {
135
+ match ( string, characters, start_pos, len) {
136
+ ( Some ( string) , Some ( characters) , Some ( start_pos) , Some ( len) ) => {
137
+ let string_len = string. chars( ) . count( ) ;
138
+ let characters_len = characters. chars( ) . count( ) ;
139
+ let replace_len = len. min( string_len as i64 ) ;
140
+ let mut res =
141
+ String :: with_capacity( string_len. max( characters_len) ) ;
142
+
143
+ //as sql replace index start from 1 while string index start from 0
144
+ if start_pos > 1 && start_pos - 1 < string_len as i64 {
145
+ let start = ( start_pos - 1 ) as usize ;
146
+ res. push_str( & string[ ..start] ) ;
147
+ }
148
+ res. push_str( characters) ;
149
+ // if start + replace_len - 1 >= string_length, just to string end
150
+ if start_pos + replace_len - 1 < string_len as i64 {
151
+ let end = ( start_pos + replace_len - 1 ) as usize ;
152
+ res. push_str( & string[ end..] ) ;
153
+ }
154
+ Ok ( Some ( res) )
155
+ }
156
+ _ => Ok ( None ) ,
157
+ }
158
+ } )
159
+ . collect:: <Result <GenericStringArray <T >>>( )
160
+ } } ;
161
+ }
162
+
86
163
/// OVERLAY(string1 PLACING string2 FROM integer FOR integer2)
87
164
/// Replaces a substring of string1 with string2 starting at the integer bit
88
165
/// pgsql overlay('Txxxxas' placing 'hom' from 2 for 4) → Thomas
89
166
/// overlay('Txxxxas' placing 'hom' from 2) -> Thomxas, without for option, str2's len is instead
90
- pub fn overlay < T : OffsetSizeTrait > ( args : & [ ArrayRef ] ) -> Result < ArrayRef > {
167
+ fn overlay < T : OffsetSizeTrait > ( args : & [ ArrayRef ] ) -> Result < ArrayRef > {
168
+ let use_string_view = args[ 0 ] . data_type ( ) == & DataType :: Utf8View ;
169
+ if use_string_view {
170
+ string_view_overlay :: < T > ( args)
171
+ } else {
172
+ string_overlay :: < T > ( args)
173
+ }
174
+ }
175
+
176
+ pub fn string_overlay < T : OffsetSizeTrait > ( args : & [ ArrayRef ] ) -> Result < ArrayRef > {
91
177
match args. len ( ) {
92
178
3 => {
93
179
let string_array = as_generic_string_array :: < T > ( & args[ 0 ] ) ?;
94
180
let characters_array = as_generic_string_array :: < T > ( & args[ 1 ] ) ?;
95
181
let pos_num = as_int64_array ( & args[ 2 ] ) ?;
96
182
97
- let result = string_array
98
- . iter ( )
99
- . zip ( characters_array. iter ( ) )
100
- . zip ( pos_num. iter ( ) )
101
- . map ( |( ( string, characters) , start_pos) | {
102
- match ( string, characters, start_pos) {
103
- ( Some ( string) , Some ( characters) , Some ( start_pos) ) => {
104
- let string_len = string. chars ( ) . count ( ) ;
105
- let characters_len = characters. chars ( ) . count ( ) ;
106
- let replace_len = characters_len as i64 ;
107
- let mut res =
108
- String :: with_capacity ( string_len. max ( characters_len) ) ;
109
-
110
- //as sql replace index start from 1 while string index start from 0
111
- if start_pos > 1 && start_pos - 1 < string_len as i64 {
112
- let start = ( start_pos - 1 ) as usize ;
113
- res. push_str ( & string[ ..start] ) ;
114
- }
115
- res. push_str ( characters) ;
116
- // if start + replace_len - 1 >= string_length, just to string end
117
- if start_pos + replace_len - 1 < string_len as i64 {
118
- let end = ( start_pos + replace_len - 1 ) as usize ;
119
- res. push_str ( & string[ end..] ) ;
120
- }
121
- Ok ( Some ( res) )
122
- }
123
- _ => Ok ( None ) ,
124
- }
125
- } )
126
- . collect :: < Result < GenericStringArray < T > > > ( ) ?;
183
+ let result = process_overlay ! ( string_array, characters_array, pos_num) ?;
127
184
Ok ( Arc :: new ( result) as ArrayRef )
128
185
}
129
186
4 => {
@@ -132,37 +189,34 @@ pub fn overlay<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
132
189
let pos_num = as_int64_array ( & args[ 2 ] ) ?;
133
190
let len_num = as_int64_array ( & args[ 3 ] ) ?;
134
191
135
- let result = string_array
136
- . iter ( )
137
- . zip ( characters_array. iter ( ) )
138
- . zip ( pos_num. iter ( ) )
139
- . zip ( len_num. iter ( ) )
140
- . map ( |( ( ( string, characters) , start_pos) , len) | {
141
- match ( string, characters, start_pos, len) {
142
- ( Some ( string) , Some ( characters) , Some ( start_pos) , Some ( len) ) => {
143
- let string_len = string. chars ( ) . count ( ) ;
144
- let characters_len = characters. chars ( ) . count ( ) ;
145
- let replace_len = len. min ( string_len as i64 ) ;
146
- let mut res =
147
- String :: with_capacity ( string_len. max ( characters_len) ) ;
148
-
149
- //as sql replace index start from 1 while string index start from 0
150
- if start_pos > 1 && start_pos - 1 < string_len as i64 {
151
- let start = ( start_pos - 1 ) as usize ;
152
- res. push_str ( & string[ ..start] ) ;
153
- }
154
- res. push_str ( characters) ;
155
- // if start + replace_len - 1 >= string_length, just to string end
156
- if start_pos + replace_len - 1 < string_len as i64 {
157
- let end = ( start_pos + replace_len - 1 ) as usize ;
158
- res. push_str ( & string[ end..] ) ;
159
- }
160
- Ok ( Some ( res) )
161
- }
162
- _ => Ok ( None ) ,
163
- }
164
- } )
165
- . collect :: < Result < GenericStringArray < T > > > ( ) ?;
192
+ let result =
193
+ process_overlay ! ( string_array, characters_array, pos_num, len_num) ?;
194
+ Ok ( Arc :: new ( result) as ArrayRef )
195
+ }
196
+ other => {
197
+ exec_err ! ( "overlay was called with {other} arguments. It requires 3 or 4." )
198
+ }
199
+ }
200
+ }
201
+
202
+ pub fn string_view_overlay < T : OffsetSizeTrait > ( args : & [ ArrayRef ] ) -> Result < ArrayRef > {
203
+ match args. len ( ) {
204
+ 3 => {
205
+ let string_array = as_string_view_array ( & args[ 0 ] ) ?;
206
+ let characters_array = as_string_view_array ( & args[ 1 ] ) ?;
207
+ let pos_num = as_int64_array ( & args[ 2 ] ) ?;
208
+
209
+ let result = process_overlay ! ( string_array, characters_array, pos_num) ?;
210
+ Ok ( Arc :: new ( result) as ArrayRef )
211
+ }
212
+ 4 => {
213
+ let string_array = as_string_view_array ( & args[ 0 ] ) ?;
214
+ let characters_array = as_string_view_array ( & args[ 1 ] ) ?;
215
+ let pos_num = as_int64_array ( & args[ 2 ] ) ?;
216
+ let len_num = as_int64_array ( & args[ 3 ] ) ?;
217
+
218
+ let result =
219
+ process_overlay ! ( string_array, characters_array, pos_num, len_num) ?;
166
220
Ok ( Arc :: new ( result) as ArrayRef )
167
221
}
168
222
other => {
0 commit comments