@@ -19,12 +19,11 @@ use std::any::Any;
19
19
use std:: str:: FromStr ;
20
20
use std:: sync:: { Arc , OnceLock } ;
21
21
22
- use arrow:: array:: { Array , ArrayRef , Float64Array } ;
22
+ use arrow:: array:: { Array , ArrayRef , Float64Array , Int32Array } ;
23
23
use arrow:: compute:: kernels:: cast_utils:: IntervalUnit ;
24
- use arrow:: compute:: { binary, cast , date_part, DatePart } ;
24
+ use arrow:: compute:: { binary, date_part, DatePart } ;
25
25
use arrow:: datatypes:: DataType :: {
26
- Date32 , Date64 , Duration , Float64 , Interval , Time32 , Time64 , Timestamp , Utf8 ,
27
- Utf8View ,
26
+ Date32 , Date64 , Duration , Interval , Time32 , Time64 , Timestamp , Utf8 , Utf8View ,
28
27
} ;
29
28
use arrow:: datatypes:: IntervalUnit :: { DayTime , MonthDayNano , YearMonth } ;
30
29
use arrow:: datatypes:: TimeUnit :: { Microsecond , Millisecond , Nanosecond , Second } ;
@@ -36,11 +35,12 @@ use datafusion_common::cast::{
36
35
as_timestamp_microsecond_array, as_timestamp_millisecond_array,
37
36
as_timestamp_nanosecond_array, as_timestamp_second_array,
38
37
} ;
39
- use datafusion_common:: { exec_err, Result , ScalarValue } ;
38
+ use datafusion_common:: { exec_err, internal_err , ExprSchema , Result , ScalarValue } ;
40
39
use datafusion_expr:: scalar_doc_sections:: DOC_SECTION_DATETIME ;
41
40
use datafusion_expr:: TypeSignature :: Exact ;
42
41
use datafusion_expr:: {
43
- ColumnarValue , Documentation , ScalarUDFImpl , Signature , Volatility , TIMEZONE_WILDCARD ,
42
+ ColumnarValue , Documentation , Expr , ScalarUDFImpl , Signature , Volatility ,
43
+ TIMEZONE_WILDCARD ,
44
44
} ;
45
45
46
46
#[ derive( Debug ) ]
@@ -148,7 +148,21 @@ impl ScalarUDFImpl for DatePartFunc {
148
148
}
149
149
150
150
fn return_type ( & self , _arg_types : & [ DataType ] ) -> Result < DataType > {
151
- Ok ( Float64 )
151
+ internal_err ! ( "return_type_from_exprs shoud be called instead" )
152
+ }
153
+
154
+ fn return_type_from_exprs (
155
+ & self ,
156
+ args : & [ Expr ] ,
157
+ _schema : & dyn ExprSchema ,
158
+ _arg_types : & [ DataType ] ,
159
+ ) -> Result < DataType > {
160
+ match & args[ 0 ] {
161
+ Expr :: Literal ( ScalarValue :: Utf8 ( Some ( part) ) ) if is_epoch ( part) => {
162
+ Ok ( DataType :: Float64 )
163
+ }
164
+ _ => Ok ( DataType :: Int32 ) ,
165
+ }
152
166
}
153
167
154
168
fn invoke ( & self , args : & [ ColumnarValue ] ) -> Result < ColumnarValue > {
@@ -174,35 +188,31 @@ impl ScalarUDFImpl for DatePartFunc {
174
188
ColumnarValue :: Scalar ( scalar) => scalar. to_array ( ) ?,
175
189
} ;
176
190
177
- // to remove quotes at most 2 characters
178
- let part_trim = part. trim_matches ( |c| c == '\'' || c == '\"' ) ;
179
- if ![ 2 , 0 ] . contains ( & ( part. len ( ) - part_trim. len ( ) ) ) {
180
- return exec_err ! ( "Date part '{part}' not supported" ) ;
181
- }
191
+ let part_trim = part_normalization ( part) ;
182
192
183
193
// using IntervalUnit here means we hand off all the work of supporting plurals (like "seconds")
184
194
// and synonyms ( like "ms,msec,msecond,millisecond") to Arrow
185
195
let arr = if let Ok ( interval_unit) = IntervalUnit :: from_str ( part_trim) {
186
196
match interval_unit {
187
- IntervalUnit :: Year => date_part_f64 ( array. as_ref ( ) , DatePart :: Year ) ?,
188
- IntervalUnit :: Month => date_part_f64 ( array. as_ref ( ) , DatePart :: Month ) ?,
189
- IntervalUnit :: Week => date_part_f64 ( array. as_ref ( ) , DatePart :: Week ) ?,
190
- IntervalUnit :: Day => date_part_f64 ( array. as_ref ( ) , DatePart :: Day ) ?,
191
- IntervalUnit :: Hour => date_part_f64 ( array. as_ref ( ) , DatePart :: Hour ) ?,
192
- IntervalUnit :: Minute => date_part_f64 ( array. as_ref ( ) , DatePart :: Minute ) ?,
193
- IntervalUnit :: Second => seconds ( array. as_ref ( ) , Second ) ?,
194
- IntervalUnit :: Millisecond => seconds ( array. as_ref ( ) , Millisecond ) ?,
195
- IntervalUnit :: Microsecond => seconds ( array. as_ref ( ) , Microsecond ) ?,
196
- IntervalUnit :: Nanosecond => seconds ( array. as_ref ( ) , Nanosecond ) ?,
197
+ IntervalUnit :: Year => date_part ( array. as_ref ( ) , DatePart :: Year ) ?,
198
+ IntervalUnit :: Month => date_part ( array. as_ref ( ) , DatePart :: Month ) ?,
199
+ IntervalUnit :: Week => date_part ( array. as_ref ( ) , DatePart :: Week ) ?,
200
+ IntervalUnit :: Day => date_part ( array. as_ref ( ) , DatePart :: Day ) ?,
201
+ IntervalUnit :: Hour => date_part ( array. as_ref ( ) , DatePart :: Hour ) ?,
202
+ IntervalUnit :: Minute => date_part ( array. as_ref ( ) , DatePart :: Minute ) ?,
203
+ IntervalUnit :: Second => seconds_as_i32 ( array. as_ref ( ) , Second ) ?,
204
+ IntervalUnit :: Millisecond => seconds_as_i32 ( array. as_ref ( ) , Millisecond ) ?,
205
+ IntervalUnit :: Microsecond => seconds_as_i32 ( array. as_ref ( ) , Microsecond ) ?,
206
+ IntervalUnit :: Nanosecond => seconds_as_i32 ( array. as_ref ( ) , Nanosecond ) ?,
197
207
// century and decade are not supported by `DatePart`, although they are supported in postgres
198
208
_ => return exec_err ! ( "Date part '{part}' not supported" ) ,
199
209
}
200
210
} else {
201
211
// special cases that can be extracted (in postgres) but are not interval units
202
212
match part_trim. to_lowercase ( ) . as_str ( ) {
203
- "qtr" | "quarter" => date_part_f64 ( array. as_ref ( ) , DatePart :: Quarter ) ?,
204
- "doy" => date_part_f64 ( array. as_ref ( ) , DatePart :: DayOfYear ) ?,
205
- "dow" => date_part_f64 ( array. as_ref ( ) , DatePart :: DayOfWeekSunday0 ) ?,
213
+ "qtr" | "quarter" => date_part ( array. as_ref ( ) , DatePart :: Quarter ) ?,
214
+ "doy" => date_part ( array. as_ref ( ) , DatePart :: DayOfYear ) ?,
215
+ "dow" => date_part ( array. as_ref ( ) , DatePart :: DayOfWeekSunday0 ) ?,
206
216
"epoch" => epoch ( array. as_ref ( ) ) ?,
207
217
_ => return exec_err ! ( "Date part '{part}' not supported" ) ,
208
218
}
@@ -223,6 +233,18 @@ impl ScalarUDFImpl for DatePartFunc {
223
233
}
224
234
}
225
235
236
+ fn is_epoch ( part : & str ) -> bool {
237
+ let part = part_normalization ( part) ;
238
+ matches ! ( part. to_lowercase( ) . as_str( ) , "epoch" )
239
+ }
240
+
241
+ // Try to remove quote if exist, if the quote is invalid, return original string and let the downstream function handle the error
242
+ fn part_normalization ( part : & str ) -> & str {
243
+ part. strip_prefix ( |c| c == '\'' || c == '\"' )
244
+ . and_then ( |s| s. strip_suffix ( |c| c == '\'' || c == '\"' ) )
245
+ . unwrap_or ( part)
246
+ }
247
+
226
248
static DOCUMENTATION : OnceLock < Documentation > = OnceLock :: new ( ) ;
227
249
228
250
fn get_date_part_doc ( ) -> & ' static Documentation {
@@ -261,14 +283,63 @@ fn get_date_part_doc() -> &'static Documentation {
261
283
} )
262
284
}
263
285
264
- /// Invoke [`date_part`] and cast the result to Float64
265
- fn date_part_f64 ( array : & dyn Array , part : DatePart ) -> Result < ArrayRef > {
266
- Ok ( cast ( date_part ( array, part) ?. as_ref ( ) , & Float64 ) ?)
286
+ /// Invoke [`date_part`] on an `array` (e.g. Timestamp) and convert the
287
+ /// result to a total number of seconds, milliseconds, microseconds or
288
+ /// nanoseconds
289
+ fn seconds_as_i32 ( array : & dyn Array , unit : TimeUnit ) -> Result < ArrayRef > {
290
+ // Nanosecond is neither supported in Postgres nor DuckDB, to avoid to deal with overflow and precision issue we don't support nanosecond
291
+ if unit == Nanosecond {
292
+ return internal_err ! ( "unit {unit:?} not supported" ) ;
293
+ }
294
+
295
+ let conversion_factor = match unit {
296
+ Second => 1_000_000_000 ,
297
+ Millisecond => 1_000_000 ,
298
+ Microsecond => 1_000 ,
299
+ Nanosecond => 1 ,
300
+ } ;
301
+
302
+ let second_factor = match unit {
303
+ Second => 1 ,
304
+ Millisecond => 1_000 ,
305
+ Microsecond => 1_000_000 ,
306
+ Nanosecond => 1_000_000_000 ,
307
+ } ;
308
+
309
+ let secs = date_part ( array, DatePart :: Second ) ?;
310
+ // This assumes array is primitive and not a dictionary
311
+ let secs = as_int32_array ( secs. as_ref ( ) ) ?;
312
+ let subsecs = date_part ( array, DatePart :: Nanosecond ) ?;
313
+ let subsecs = as_int32_array ( subsecs. as_ref ( ) ) ?;
314
+
315
+ // Special case where there are no nulls.
316
+ if subsecs. null_count ( ) == 0 {
317
+ let r: Int32Array = binary ( secs, subsecs, |secs, subsecs| {
318
+ secs * second_factor + ( subsecs % 1_000_000_000 ) / conversion_factor
319
+ } ) ?;
320
+ Ok ( Arc :: new ( r) )
321
+ } else {
322
+ // Nulls in secs are preserved, nulls in subsecs are treated as zero to account for the case
323
+ // where the number of nanoseconds overflows.
324
+ let r: Int32Array = secs
325
+ . iter ( )
326
+ . zip ( subsecs)
327
+ . map ( |( secs, subsecs) | {
328
+ secs. map ( |secs| {
329
+ let subsecs = subsecs. unwrap_or ( 0 ) ;
330
+ secs * second_factor + ( subsecs % 1_000_000_000 ) / conversion_factor
331
+ } )
332
+ } )
333
+ . collect ( ) ;
334
+ Ok ( Arc :: new ( r) )
335
+ }
267
336
}
268
337
269
338
/// Invoke [`date_part`] on an `array` (e.g. Timestamp) and convert the
270
339
/// result to a total number of seconds, milliseconds, microseconds or
271
340
/// nanoseconds
341
+ ///
342
+ /// Given epoch return f64, this is a duplicated function to optimize for f64 type
272
343
fn seconds ( array : & dyn Array , unit : TimeUnit ) -> Result < ArrayRef > {
273
344
let sf = match unit {
274
345
Second => 1_f64 ,
0 commit comments