Skip to content

Commit 58761ac

Browse files
Return int32 for integer type date part (#13466)
* return int for integer date part Signed-off-by: jayzhan211 <[email protected]> * fix tpch test Signed-off-by: jayzhan211 <[email protected]> * type test Signed-off-by: jayzhan211 <[email protected]> * Update datafusion/functions/src/datetime/date_part.rs Co-authored-by: Daniël Heres <[email protected]> * fix name Signed-off-by: jayzhan211 <[email protected]> * use int for second Signed-off-by: jayzhan211 <[email protected]> * rm dot Signed-off-by: Jay Zhan <[email protected]> --------- Signed-off-by: jayzhan211 <[email protected]> Signed-off-by: Jay Zhan <[email protected]> Co-authored-by: Daniël Heres <[email protected]>
1 parent dd4fa79 commit 58761ac

File tree

8 files changed

+282
-243
lines changed

8 files changed

+282
-243
lines changed

datafusion/functions/src/datetime/date_part.rs

Lines changed: 99 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,11 @@ use std::any::Any;
1919
use std::str::FromStr;
2020
use std::sync::{Arc, OnceLock};
2121

22-
use arrow::array::{Array, ArrayRef, Float64Array};
22+
use arrow::array::{Array, ArrayRef, Float64Array, Int32Array};
2323
use arrow::compute::kernels::cast_utils::IntervalUnit;
24-
use arrow::compute::{binary, cast, date_part, DatePart};
24+
use arrow::compute::{binary, date_part, DatePart};
2525
use arrow::datatypes::DataType::{
26-
Date32, Date64, Duration, Float64, Interval, Time32, Time64, Timestamp, Utf8,
27-
Utf8View,
26+
Date32, Date64, Duration, Interval, Time32, Time64, Timestamp, Utf8, Utf8View,
2827
};
2928
use arrow::datatypes::IntervalUnit::{DayTime, MonthDayNano, YearMonth};
3029
use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second};
@@ -36,11 +35,12 @@ use datafusion_common::cast::{
3635
as_timestamp_microsecond_array, as_timestamp_millisecond_array,
3736
as_timestamp_nanosecond_array, as_timestamp_second_array,
3837
};
39-
use datafusion_common::{exec_err, Result, ScalarValue};
38+
use datafusion_common::{exec_err, internal_err, ExprSchema, Result, ScalarValue};
4039
use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME;
4140
use datafusion_expr::TypeSignature::Exact;
4241
use datafusion_expr::{
43-
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD,
42+
ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility,
43+
TIMEZONE_WILDCARD,
4444
};
4545

4646
#[derive(Debug)]
@@ -148,7 +148,21 @@ impl ScalarUDFImpl for DatePartFunc {
148148
}
149149

150150
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
151-
Ok(Float64)
151+
internal_err!("return_type_from_exprs shoud be called instead")
152+
}
153+
154+
fn return_type_from_exprs(
155+
&self,
156+
args: &[Expr],
157+
_schema: &dyn ExprSchema,
158+
_arg_types: &[DataType],
159+
) -> Result<DataType> {
160+
match &args[0] {
161+
Expr::Literal(ScalarValue::Utf8(Some(part))) if is_epoch(part) => {
162+
Ok(DataType::Float64)
163+
}
164+
_ => Ok(DataType::Int32),
165+
}
152166
}
153167

154168
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
@@ -174,35 +188,31 @@ impl ScalarUDFImpl for DatePartFunc {
174188
ColumnarValue::Scalar(scalar) => scalar.to_array()?,
175189
};
176190

177-
// to remove quotes at most 2 characters
178-
let part_trim = part.trim_matches(|c| c == '\'' || c == '\"');
179-
if ![2, 0].contains(&(part.len() - part_trim.len())) {
180-
return exec_err!("Date part '{part}' not supported");
181-
}
191+
let part_trim = part_normalization(part);
182192

183193
// using IntervalUnit here means we hand off all the work of supporting plurals (like "seconds")
184194
// and synonyms ( like "ms,msec,msecond,millisecond") to Arrow
185195
let arr = if let Ok(interval_unit) = IntervalUnit::from_str(part_trim) {
186196
match interval_unit {
187-
IntervalUnit::Year => date_part_f64(array.as_ref(), DatePart::Year)?,
188-
IntervalUnit::Month => date_part_f64(array.as_ref(), DatePart::Month)?,
189-
IntervalUnit::Week => date_part_f64(array.as_ref(), DatePart::Week)?,
190-
IntervalUnit::Day => date_part_f64(array.as_ref(), DatePart::Day)?,
191-
IntervalUnit::Hour => date_part_f64(array.as_ref(), DatePart::Hour)?,
192-
IntervalUnit::Minute => date_part_f64(array.as_ref(), DatePart::Minute)?,
193-
IntervalUnit::Second => seconds(array.as_ref(), Second)?,
194-
IntervalUnit::Millisecond => seconds(array.as_ref(), Millisecond)?,
195-
IntervalUnit::Microsecond => seconds(array.as_ref(), Microsecond)?,
196-
IntervalUnit::Nanosecond => seconds(array.as_ref(), Nanosecond)?,
197+
IntervalUnit::Year => date_part(array.as_ref(), DatePart::Year)?,
198+
IntervalUnit::Month => date_part(array.as_ref(), DatePart::Month)?,
199+
IntervalUnit::Week => date_part(array.as_ref(), DatePart::Week)?,
200+
IntervalUnit::Day => date_part(array.as_ref(), DatePart::Day)?,
201+
IntervalUnit::Hour => date_part(array.as_ref(), DatePart::Hour)?,
202+
IntervalUnit::Minute => date_part(array.as_ref(), DatePart::Minute)?,
203+
IntervalUnit::Second => seconds_as_i32(array.as_ref(), Second)?,
204+
IntervalUnit::Millisecond => seconds_as_i32(array.as_ref(), Millisecond)?,
205+
IntervalUnit::Microsecond => seconds_as_i32(array.as_ref(), Microsecond)?,
206+
IntervalUnit::Nanosecond => seconds_as_i32(array.as_ref(), Nanosecond)?,
197207
// century and decade are not supported by `DatePart`, although they are supported in postgres
198208
_ => return exec_err!("Date part '{part}' not supported"),
199209
}
200210
} else {
201211
// special cases that can be extracted (in postgres) but are not interval units
202212
match part_trim.to_lowercase().as_str() {
203-
"qtr" | "quarter" => date_part_f64(array.as_ref(), DatePart::Quarter)?,
204-
"doy" => date_part_f64(array.as_ref(), DatePart::DayOfYear)?,
205-
"dow" => date_part_f64(array.as_ref(), DatePart::DayOfWeekSunday0)?,
213+
"qtr" | "quarter" => date_part(array.as_ref(), DatePart::Quarter)?,
214+
"doy" => date_part(array.as_ref(), DatePart::DayOfYear)?,
215+
"dow" => date_part(array.as_ref(), DatePart::DayOfWeekSunday0)?,
206216
"epoch" => epoch(array.as_ref())?,
207217
_ => return exec_err!("Date part '{part}' not supported"),
208218
}
@@ -223,6 +233,18 @@ impl ScalarUDFImpl for DatePartFunc {
223233
}
224234
}
225235

236+
fn is_epoch(part: &str) -> bool {
237+
let part = part_normalization(part);
238+
matches!(part.to_lowercase().as_str(), "epoch")
239+
}
240+
241+
/// Strips one leading and one trailing quote character (`'` or `"`) from
/// `part`.
///
/// If `part` does not both start and end with a quote, the original string is
/// returned unchanged so that the downstream parser reports the error.
fn part_normalization(part: &str) -> &str {
    const QUOTES: [char; 2] = ['\'', '"'];

    match part.strip_prefix(QUOTES) {
        // A leading quote was found; the result is only accepted when a
        // trailing quote can be removed too.
        Some(rest) => rest.strip_suffix(QUOTES).unwrap_or(part),
        None => part,
    }
}
247+
226248
static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
227249

228250
fn get_date_part_doc() -> &'static Documentation {
@@ -261,14 +283,63 @@ fn get_date_part_doc() -> &'static Documentation {
261283
})
262284
}
263285

264-
/// Invoke [`date_part`] and cast the result to Float64
265-
fn date_part_f64(array: &dyn Array, part: DatePart) -> Result<ArrayRef> {
266-
Ok(cast(date_part(array, part)?.as_ref(), &Float64)?)
286+
/// Invoke [`date_part`] on an `array` (e.g. Timestamp) and convert the
287+
/// result to a total number of seconds, milliseconds, microseconds or
288+
/// nanoseconds
289+
fn seconds_as_i32(array: &dyn Array, unit: TimeUnit) -> Result<ArrayRef> {
290+
// Nanosecond is neither supported in Postgres nor DuckDB, to avoid to deal with overflow and precision issue we don't support nanosecond
291+
if unit == Nanosecond {
292+
return internal_err!("unit {unit:?} not supported");
293+
}
294+
295+
let conversion_factor = match unit {
296+
Second => 1_000_000_000,
297+
Millisecond => 1_000_000,
298+
Microsecond => 1_000,
299+
Nanosecond => 1,
300+
};
301+
302+
let second_factor = match unit {
303+
Second => 1,
304+
Millisecond => 1_000,
305+
Microsecond => 1_000_000,
306+
Nanosecond => 1_000_000_000,
307+
};
308+
309+
let secs = date_part(array, DatePart::Second)?;
310+
// This assumes array is primitive and not a dictionary
311+
let secs = as_int32_array(secs.as_ref())?;
312+
let subsecs = date_part(array, DatePart::Nanosecond)?;
313+
let subsecs = as_int32_array(subsecs.as_ref())?;
314+
315+
// Special case where there are no nulls.
316+
if subsecs.null_count() == 0 {
317+
let r: Int32Array = binary(secs, subsecs, |secs, subsecs| {
318+
secs * second_factor + (subsecs % 1_000_000_000) / conversion_factor
319+
})?;
320+
Ok(Arc::new(r))
321+
} else {
322+
// Nulls in secs are preserved, nulls in subsecs are treated as zero to account for the case
323+
// where the number of nanoseconds overflows.
324+
let r: Int32Array = secs
325+
.iter()
326+
.zip(subsecs)
327+
.map(|(secs, subsecs)| {
328+
secs.map(|secs| {
329+
let subsecs = subsecs.unwrap_or(0);
330+
secs * second_factor + (subsecs % 1_000_000_000) / conversion_factor
331+
})
332+
})
333+
.collect();
334+
Ok(Arc::new(r))
335+
}
267336
}
268337

269338
/// Invoke [`date_part`] on an `array` (e.g. Timestamp) and convert the
270339
/// result to a total number of seconds, milliseconds, microseconds or
271340
/// nanoseconds
341+
///
342+
/// Because `epoch` returns f64, this function duplicates [`seconds_as_i32`], specialized for the f64 type
272343
fn seconds(array: &dyn Array, unit: TimeUnit) -> Result<ArrayRef> {
273344
let sf = match unit {
274345
Second => 1_f64,

datafusion/sqllogictest/test_files/clickbench.slt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ SELECT "UserID", "SearchPhrase", COUNT(*) FROM hits GROUP BY "UserID", "SearchPh
136136
519640690937130534 (empty) 2
137137
7418527520126366595 (empty) 1
138138

139-
query IRTI rowsort
139+
query IITI rowsort
140140
SELECT "UserID", extract(minute FROM to_timestamp_seconds("EventTime")) AS m, "SearchPhrase", COUNT(*) FROM hits GROUP BY "UserID", m, "SearchPhrase" ORDER BY COUNT(*) DESC LIMIT 10;
141141
----
142142
-2461439046089301801 18 (empty) 1

0 commit comments

Comments
 (0)