18
18
use std:: any:: Any ;
19
19
use std:: sync:: Arc ;
20
20
21
- use arrow:: array:: { ArrayRef , GenericStringArray , OffsetSizeTrait } ;
21
+ use arrow:: array:: { ArrayRef , GenericStringArray , OffsetSizeTrait , StringArray } ;
22
22
use arrow:: datatypes:: DataType ;
23
23
24
- use datafusion_common:: cast:: { as_generic_string_array, as_int64_array} ;
24
+ use datafusion_common:: cast:: {
25
+ as_generic_string_array, as_int64_array, as_string_view_array,
26
+ } ;
25
27
use datafusion_common:: { exec_err, Result } ;
26
28
use datafusion_expr:: TypeSignature :: * ;
27
29
use datafusion_expr:: { ColumnarValue , Volatility } ;
@@ -45,7 +47,14 @@ impl RepeatFunc {
45
47
use DataType :: * ;
46
48
Self {
47
49
signature : Signature :: one_of (
48
- vec ! [ Exact ( vec![ Utf8 , Int64 ] ) , Exact ( vec![ LargeUtf8 , Int64 ] ) ] ,
50
+ vec ! [
51
+ // Planner attempts coercion to the target type starting with the most preferred candidate.
52
+ // For example, given input `(Utf8View, Int64)`, it first tries coercing to `(Utf8View, Int64)`.
53
+ // If that fails, it proceeds to `(Utf8, Int64)`.
54
+ Exact ( vec![ Utf8View , Int64 ] ) ,
55
+ Exact ( vec![ Utf8 , Int64 ] ) ,
56
+ Exact ( vec![ LargeUtf8 , Int64 ] ) ,
57
+ ] ,
49
58
Volatility :: Immutable ,
50
59
) ,
51
60
}
@@ -71,9 +80,10 @@ impl ScalarUDFImpl for RepeatFunc {
71
80
72
81
fn invoke ( & self , args : & [ ColumnarValue ] ) -> Result < ColumnarValue > {
73
82
match args[ 0 ] . data_type ( ) {
83
+ DataType :: Utf8View => make_scalar_function ( repeat_utf8view, vec ! [ ] ) ( args) ,
74
84
DataType :: Utf8 => make_scalar_function ( repeat :: < i32 > , vec ! [ ] ) ( args) ,
75
85
DataType :: LargeUtf8 => make_scalar_function ( repeat :: < i64 > , vec ! [ ] ) ( args) ,
76
- other => exec_err ! ( "Unsupported data type {other:?} for function repeat" ) ,
86
+ other => exec_err ! ( "Unsupported data type {other:?} for function repeat. Expected Utf8, Utf8View or LargeUtf8 " ) ,
77
87
}
78
88
}
79
89
}
@@ -87,18 +97,35 @@ fn repeat<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
87
97
let result = string_array
88
98
. iter ( )
89
99
. zip ( number_array. iter ( ) )
90
- . map ( |( string, number) | match ( string, number) {
91
- ( Some ( string) , Some ( number) ) if number >= 0 => {
92
- Some ( string. repeat ( number as usize ) )
93
- }
94
- ( Some ( _) , Some ( _) ) => Some ( "" . to_string ( ) ) ,
95
- _ => None ,
96
- } )
100
+ . map ( |( string, number) | repeat_common ( string, number) )
97
101
. collect :: < GenericStringArray < T > > ( ) ;
98
102
99
103
Ok ( Arc :: new ( result) as ArrayRef )
100
104
}
101
105
106
+ fn repeat_utf8view ( args : & [ ArrayRef ] ) -> Result < ArrayRef > {
107
+ let string_view_array = as_string_view_array ( & args[ 0 ] ) ?;
108
+ let number_array = as_int64_array ( & args[ 1 ] ) ?;
109
+
110
+ let result = string_view_array
111
+ . iter ( )
112
+ . zip ( number_array. iter ( ) )
113
+ . map ( |( string, number) | repeat_common ( string, number) )
114
+ . collect :: < StringArray > ( ) ;
115
+
116
+ Ok ( Arc :: new ( result) as ArrayRef )
117
+ }
118
+
119
+ fn repeat_common ( string : Option < & str > , number : Option < i64 > ) -> Option < String > {
120
+ match ( string, number) {
121
+ ( Some ( string) , Some ( number) ) if number >= 0 => {
122
+ Some ( string. repeat ( number as usize ) )
123
+ }
124
+ ( Some ( _) , Some ( _) ) => Some ( "" . to_string ( ) ) ,
125
+ _ => None ,
126
+ }
127
+ }
128
+
102
129
#[ cfg( test) ]
103
130
mod tests {
104
131
use arrow:: array:: { Array , StringArray } ;
@@ -124,7 +151,6 @@ mod tests {
124
151
Utf8 ,
125
152
StringArray
126
153
) ;
127
-
128
154
test_function ! (
129
155
RepeatFunc :: new( ) ,
130
156
& [
@@ -148,6 +174,40 @@ mod tests {
148
174
StringArray
149
175
) ;
150
176
177
+ test_function ! (
178
+ RepeatFunc :: new( ) ,
179
+ & [
180
+ ColumnarValue :: Scalar ( ScalarValue :: Utf8View ( Some ( String :: from( "Pg" ) ) ) ) ,
181
+ ColumnarValue :: Scalar ( ScalarValue :: Int64 ( Some ( 4 ) ) ) ,
182
+ ] ,
183
+ Ok ( Some ( "PgPgPgPg" ) ) ,
184
+ & str ,
185
+ Utf8 ,
186
+ StringArray
187
+ ) ;
188
+ test_function ! (
189
+ RepeatFunc :: new( ) ,
190
+ & [
191
+ ColumnarValue :: Scalar ( ScalarValue :: Utf8View ( None ) ) ,
192
+ ColumnarValue :: Scalar ( ScalarValue :: Int64 ( Some ( 4 ) ) ) ,
193
+ ] ,
194
+ Ok ( None ) ,
195
+ & str ,
196
+ Utf8 ,
197
+ StringArray
198
+ ) ;
199
+ test_function ! (
200
+ RepeatFunc :: new( ) ,
201
+ & [
202
+ ColumnarValue :: Scalar ( ScalarValue :: Utf8View ( Some ( String :: from( "Pg" ) ) ) ) ,
203
+ ColumnarValue :: Scalar ( ScalarValue :: Int64 ( None ) ) ,
204
+ ] ,
205
+ Ok ( None ) ,
206
+ & str ,
207
+ Utf8 ,
208
+ StringArray
209
+ ) ;
210
+
151
211
Ok ( ( ) )
152
212
}
153
213
}
0 commit comments