Skip to content

Commit 3f0fb4a

Browse files
authored
Upgrade sqlparser-rs to 0.51.0, support new interval logic from sqlparse-rs (#12222)
* support new interval logic from sqlparse-rs * uprev sqlparser-rs branch * use sqlparser 51 * better extract logic and interval testing * revert unnecessary changes * revert unnecessary changes, more * cleanup * fix last failing test :fingerscrossed:
1 parent f48e0b2 commit 3f0fb4a

File tree

10 files changed

+215
-222
lines changed

10 files changed

+215
-222
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ rand = "0.8"
137137
regex = "1.8"
138138
rstest = "0.22.0"
139139
serde_json = "1"
140-
sqlparser = { version = "0.50.0", features = ["visitor"] }
140+
sqlparser = { version = "0.51.0", features = ["visitor"] }
141141
tempfile = "3"
142142
thiserror = "1.0.44"
143143
tokio = { version = "1.36", features = ["macros", "rt", "sync"] }

datafusion-cli/Cargo.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/functions/src/datetime/date_part.rs

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@
1616
// under the License.
1717

1818
use std::any::Any;
19+
use std::str::FromStr;
1920
use std::sync::Arc;
2021

2122
use arrow::array::{Array, ArrayRef, Float64Array};
23+
use arrow::compute::kernels::cast_utils::IntervalUnit;
2224
use arrow::compute::{binary, cast, date_part, DatePart};
2325
use arrow::datatypes::DataType::{
2426
Date32, Date64, Float64, Time32, Time64, Timestamp, Utf8, Utf8View,
@@ -161,22 +163,32 @@ impl ScalarUDFImpl for DatePartFunc {
161163
return exec_err!("Date part '{part}' not supported");
162164
}
163165

164-
let arr = match part_trim.to_lowercase().as_str() {
165-
"year" => date_part_f64(array.as_ref(), DatePart::Year)?,
166-
"quarter" => date_part_f64(array.as_ref(), DatePart::Quarter)?,
167-
"month" => date_part_f64(array.as_ref(), DatePart::Month)?,
168-
"week" => date_part_f64(array.as_ref(), DatePart::Week)?,
169-
"day" => date_part_f64(array.as_ref(), DatePart::Day)?,
170-
"doy" => date_part_f64(array.as_ref(), DatePart::DayOfYear)?,
171-
"dow" => date_part_f64(array.as_ref(), DatePart::DayOfWeekSunday0)?,
172-
"hour" => date_part_f64(array.as_ref(), DatePart::Hour)?,
173-
"minute" => date_part_f64(array.as_ref(), DatePart::Minute)?,
174-
"second" => seconds(array.as_ref(), Second)?,
175-
"millisecond" => seconds(array.as_ref(), Millisecond)?,
176-
"microsecond" => seconds(array.as_ref(), Microsecond)?,
177-
"nanosecond" => seconds(array.as_ref(), Nanosecond)?,
178-
"epoch" => epoch(array.as_ref())?,
179-
_ => return exec_err!("Date part '{part}' not supported"),
166+
// using IntervalUnit here means we hand off all the work of supporting plurals (like "seconds")
167+
// and synonyms ( like "ms,msec,msecond,millisecond") to Arrow
168+
let arr = if let Ok(interval_unit) = IntervalUnit::from_str(part_trim) {
169+
match interval_unit {
170+
IntervalUnit::Year => date_part_f64(array.as_ref(), DatePart::Year)?,
171+
IntervalUnit::Month => date_part_f64(array.as_ref(), DatePart::Month)?,
172+
IntervalUnit::Week => date_part_f64(array.as_ref(), DatePart::Week)?,
173+
IntervalUnit::Day => date_part_f64(array.as_ref(), DatePart::Day)?,
174+
IntervalUnit::Hour => date_part_f64(array.as_ref(), DatePart::Hour)?,
175+
IntervalUnit::Minute => date_part_f64(array.as_ref(), DatePart::Minute)?,
176+
IntervalUnit::Second => seconds(array.as_ref(), Second)?,
177+
IntervalUnit::Millisecond => seconds(array.as_ref(), Millisecond)?,
178+
IntervalUnit::Microsecond => seconds(array.as_ref(), Microsecond)?,
179+
IntervalUnit::Nanosecond => seconds(array.as_ref(), Nanosecond)?,
180+
// century and decade are not supported by `DatePart`, although they are supported in postgres
181+
_ => return exec_err!("Date part '{part}' not supported"),
182+
}
183+
} else {
184+
// special cases that can be extracted (in postgres) but are not interval units
185+
match part_trim.to_lowercase().as_str() {
186+
"qtr" | "quarter" => date_part_f64(array.as_ref(), DatePart::Quarter)?,
187+
"doy" => date_part_f64(array.as_ref(), DatePart::DayOfYear)?,
188+
"dow" => date_part_f64(array.as_ref(), DatePart::DayOfWeekSunday0)?,
189+
"epoch" => epoch(array.as_ref())?,
190+
_ => return exec_err!("Date part '{part}' not supported"),
191+
}
180192
};
181193

182194
Ok(if is_scalar {

datafusion/sql/src/expr/mod.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -201,9 +201,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
201201
}
202202

203203
SQLExpr::Array(arr) => self.sql_array_literal(arr.elem, schema),
204-
SQLExpr::Interval(interval) => {
205-
self.sql_interval_to_expr(false, interval, schema, planner_context)
206-
}
204+
SQLExpr::Interval(interval) => self.sql_interval_to_expr(false, interval),
207205
SQLExpr::Identifier(id) => {
208206
self.sql_identifier_to_expr(id, schema, planner_context)
209207
}

datafusion/sql/src/expr/unary_op.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
4343
self.parse_sql_number(&n, true)
4444
}
4545
SQLExpr::Interval(interval) => {
46-
self.sql_interval_to_expr(true, interval, schema, planner_context)
46+
self.sql_interval_to_expr(true, interval)
4747
}
4848
// not a literal, apply negative operator on expression
4949
_ => Ok(Expr::Negative(Box::new(self.sql_expr_to_logical_expr(

datafusion/sql/src/expr/value.rs

Lines changed: 72 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ use datafusion_expr::expr::{BinaryExpr, Placeholder};
2626
use datafusion_expr::planner::PlannerResult;
2727
use datafusion_expr::{lit, Expr, Operator};
2828
use log::debug;
29-
use sqlparser::ast::{BinaryOperator, Expr as SQLExpr, Interval, Value};
29+
use sqlparser::ast::{BinaryOperator, Expr as SQLExpr, Interval, UnaryOperator, Value};
3030
use sqlparser::parser::ParserError::ParserError;
3131
use std::borrow::Cow;
3232

@@ -168,12 +168,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
168168

169169
/// Convert a SQL interval expression to a DataFusion logical plan
170170
/// expression
171+
#[allow(clippy::only_used_in_recursion)]
171172
pub(super) fn sql_interval_to_expr(
172173
&self,
173174
negative: bool,
174175
interval: Interval,
175-
schema: &DFSchema,
176-
planner_context: &mut PlannerContext,
177176
) -> Result<Expr> {
178177
if interval.leading_precision.is_some() {
179178
return not_impl_err!(
@@ -196,127 +195,42 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
196195
);
197196
}
198197

199-
// Only handle string exprs for now
200-
let value = match *interval.value {
201-
SQLExpr::Value(
202-
Value::SingleQuotedString(s) | Value::DoubleQuotedString(s),
203-
) => {
204-
if negative {
205-
format!("-{s}")
206-
} else {
207-
s
208-
}
209-
}
210-
// Support expressions like `interval '1 month' + date/timestamp`.
211-
// Such expressions are parsed like this by sqlparser-rs
212-
//
213-
// Interval
214-
// BinaryOp
215-
// Value(StringLiteral)
216-
// Cast
217-
// Value(StringLiteral)
218-
//
219-
// This code rewrites them to the following:
220-
//
221-
// BinaryOp
222-
// Interval
223-
// Value(StringLiteral)
224-
// Cast
225-
// Value(StringLiteral)
226-
SQLExpr::BinaryOp { left, op, right } => {
227-
let df_op = match op {
228-
BinaryOperator::Plus => Operator::Plus,
229-
BinaryOperator::Minus => Operator::Minus,
230-
BinaryOperator::Eq => Operator::Eq,
231-
BinaryOperator::NotEq => Operator::NotEq,
232-
BinaryOperator::Gt => Operator::Gt,
233-
BinaryOperator::GtEq => Operator::GtEq,
234-
BinaryOperator::Lt => Operator::Lt,
235-
BinaryOperator::LtEq => Operator::LtEq,
236-
_ => {
237-
return not_impl_err!("Unsupported interval operator: {op:?}");
238-
}
239-
};
240-
match (
241-
interval.leading_field.as_ref(),
242-
left.as_ref(),
243-
right.as_ref(),
244-
) {
245-
(_, _, SQLExpr::Value(_)) => {
246-
let left_expr = self.sql_interval_to_expr(
247-
negative,
248-
Interval {
249-
value: left,
250-
leading_field: interval.leading_field.clone(),
251-
leading_precision: None,
252-
last_field: None,
253-
fractional_seconds_precision: None,
254-
},
255-
schema,
256-
planner_context,
257-
)?;
258-
let right_expr = self.sql_interval_to_expr(
259-
false,
260-
Interval {
261-
value: right,
262-
leading_field: interval.leading_field,
263-
leading_precision: None,
264-
last_field: None,
265-
fractional_seconds_precision: None,
266-
},
267-
schema,
268-
planner_context,
269-
)?;
270-
return Ok(Expr::BinaryExpr(BinaryExpr::new(
271-
Box::new(left_expr),
272-
df_op,
273-
Box::new(right_expr),
274-
)));
275-
}
276-
// In this case, the left node is part of the interval
277-
// expr and the right node is an independent expr.
278-
//
279-
// Leading field is not supported when the right operand
280-
// is not a value.
281-
(None, _, _) => {
282-
let left_expr = self.sql_interval_to_expr(
283-
negative,
284-
Interval {
285-
value: left,
286-
leading_field: None,
287-
leading_precision: None,
288-
last_field: None,
289-
fractional_seconds_precision: None,
290-
},
291-
schema,
292-
planner_context,
293-
)?;
294-
let right_expr = self.sql_expr_to_logical_expr(
295-
*right,
296-
schema,
297-
planner_context,
298-
)?;
299-
return Ok(Expr::BinaryExpr(BinaryExpr::new(
300-
Box::new(left_expr),
301-
df_op,
302-
Box::new(right_expr),
303-
)));
304-
}
305-
_ => {
306-
let value = SQLExpr::BinaryOp { left, op, right };
307-
return not_impl_err!(
308-
"Unsupported interval argument. Expected string literal, got: {value:?}"
309-
);
310-
}
198+
if let SQLExpr::BinaryOp { left, op, right } = *interval.value {
199+
let df_op = match op {
200+
BinaryOperator::Plus => Operator::Plus,
201+
BinaryOperator::Minus => Operator::Minus,
202+
_ => {
203+
return not_impl_err!("Unsupported interval operator: {op:?}");
311204
}
312-
}
313-
_ => {
314-
return not_impl_err!(
315-
"Unsupported interval argument. Expected string literal, got: {:?}",
316-
interval.value
317-
);
318-
}
319-
};
205+
};
206+
let left_expr = self.sql_interval_to_expr(
207+
negative,
208+
Interval {
209+
value: left,
210+
leading_field: interval.leading_field.clone(),
211+
leading_precision: None,
212+
last_field: None,
213+
fractional_seconds_precision: None,
214+
},
215+
)?;
216+
let right_expr = self.sql_interval_to_expr(
217+
false,
218+
Interval {
219+
value: right,
220+
leading_field: interval.leading_field,
221+
leading_precision: None,
222+
last_field: None,
223+
fractional_seconds_precision: None,
224+
},
225+
)?;
226+
return Ok(Expr::BinaryExpr(BinaryExpr::new(
227+
Box::new(left_expr),
228+
df_op,
229+
Box::new(right_expr),
230+
)));
231+
}
232+
233+
let value = interval_literal(*interval.value, negative)?;
320234

321235
let value = if has_units(&value) {
322236
// If the interval already contains a unit
@@ -343,6 +257,41 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
343257
}
344258
}
345259

260+
fn interval_literal(interval_value: SQLExpr, negative: bool) -> Result<String> {
261+
let s = match interval_value {
262+
SQLExpr::Value(Value::SingleQuotedString(s) | Value::DoubleQuotedString(s)) => s,
263+
SQLExpr::Value(Value::Number(ref v, long)) => {
264+
if long {
265+
return not_impl_err!(
266+
"Unsupported interval argument. Long number not supported: {interval_value:?}"
267+
);
268+
} else {
269+
v.to_string()
270+
}
271+
}
272+
SQLExpr::UnaryOp { op, expr } => {
273+
let negative = match op {
274+
UnaryOperator::Minus => !negative,
275+
UnaryOperator::Plus => negative,
276+
_ => {
277+
return not_impl_err!(
278+
"Unsupported SQL unary operator in interval {op:?}"
279+
);
280+
}
281+
};
282+
interval_literal(*expr, negative)?
283+
}
284+
_ => {
285+
return not_impl_err!("Unsupported interval argument. Expected string literal or number, got: {interval_value:?}");
286+
}
287+
};
288+
if negative {
289+
Ok(format!("-{s}"))
290+
} else {
291+
Ok(s)
292+
}
293+
}
294+
346295
// TODO make interval parsing better in arrow-rs / expose `IntervalType`
347296
fn has_units(val: &str) -> bool {
348297
let val = val.to_lowercase();

datafusion/sql/tests/cases/plan_to_sql.rs

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -495,11 +495,17 @@ fn test_table_references_in_plan_to_sql() {
495495
assert_eq!(format!("{}", sql), expected_sql)
496496
}
497497

498-
test("catalog.schema.table", "SELECT catalog.\"schema\".\"table\".id, catalog.\"schema\".\"table\".\"value\" FROM catalog.\"schema\".\"table\"");
499-
test("schema.table", "SELECT \"schema\".\"table\".id, \"schema\".\"table\".\"value\" FROM \"schema\".\"table\"");
498+
test(
499+
"catalog.schema.table",
500+
r#"SELECT "catalog"."schema"."table".id, "catalog"."schema"."table"."value" FROM "catalog"."schema"."table""#,
501+
);
502+
test(
503+
"schema.table",
504+
r#"SELECT "schema"."table".id, "schema"."table"."value" FROM "schema"."table""#,
505+
);
500506
test(
501507
"table",
502-
"SELECT \"table\".id, \"table\".\"value\" FROM \"table\"",
508+
r#"SELECT "table".id, "table"."value" FROM "table""#,
503509
);
504510
}
505511

@@ -521,10 +527,10 @@ fn test_table_scan_with_no_projection_in_plan_to_sql() {
521527

522528
test(
523529
"catalog.schema.table",
524-
"SELECT * FROM catalog.\"schema\".\"table\"",
530+
r#"SELECT * FROM "catalog"."schema"."table""#,
525531
);
526-
test("schema.table", "SELECT * FROM \"schema\".\"table\"");
527-
test("table", "SELECT * FROM \"table\"");
532+
test("schema.table", r#"SELECT * FROM "schema"."table""#);
533+
test("table", r#"SELECT * FROM "table""#);
528534
}
529535

530536
#[test]

datafusion/sqllogictest/test_files/expr.slt

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1355,6 +1355,16 @@ SELECT date_part('second', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanose
13551355
----
13561356
50.123456789
13571357

1358+
query R
1359+
select extract(second from '2024-08-09T12:13:14')
1360+
----
1361+
14
1362+
1363+
query R
1364+
select extract(seconds from '2024-08-09T12:13:14')
1365+
----
1366+
14
1367+
13581368
query R
13591369
SELECT extract(second from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)'))
13601370
----
@@ -1381,6 +1391,11 @@ SELECT extract(microsecond from arrow_cast('23:32:50.123456789'::time, 'Time64(N
13811391
----
13821392
50123456.789000005
13831393

1394+
query R
1395+
SELECT extract(us from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)'))
1396+
----
1397+
50123456.789000005
1398+
13841399
query R
13851400
SELECT date_part('nanosecond', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)'))
13861401
----

0 commit comments

Comments
 (0)