Skip to content

Commit e1f866e

Browse files
authored
Add optimizer test for simplifying predicates on timestamps (#3939)
1 parent b5c23c2 commit e1f866e

File tree

1 file changed

+80
-30
lines changed

1 file changed

+80
-30
lines changed

datafusion/optimizer/tests/integration-test.rs

Lines changed: 80 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,8 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
18+
use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
19+
use chrono::{DateTime, NaiveDateTime, Utc};
1920
use datafusion_common::{DataFusionError, Result};
2021
use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource};
2122
use datafusion_optimizer::optimizer::Optimizer;
@@ -87,11 +88,11 @@ fn case_when_aggregate() -> Result<()> {
8788

8889
#[test]
8990
fn unsigned_target_type() -> Result<()> {
90-
let sql = "SELECT * FROM test WHERE col_uint32 > 0";
91+
let sql = "SELECT col_utf8 FROM test WHERE col_uint32 > 0";
9192
let plan = test_sql(sql)?;
92-
let expected = "Projection: test.col_int32, test.col_uint32, test.col_utf8, test.col_date32, test.col_date64\
93-
\n Filter: CAST(test.col_uint32 AS Int64) > Int64(0)\
94-
\n TableScan: test projection=[col_int32, col_uint32, col_utf8, col_date32, col_date64]";
93+
let expected = "Projection: test.col_utf8\
94+
\n Filter: CAST(test.col_uint32 AS Int64) > Int64(0)\
95+
\n TableScan: test projection=[col_uint32, col_utf8]";
9596
assert_eq!(expected, format!("{:?}", plan));
9697
Ok(())
9798
}
@@ -111,46 +112,46 @@ fn distribute_by() -> Result<()> {
111112
#[test]
112113
fn semi_join_with_join_filter() -> Result<()> {
113114
// regression test for https://github.com/apache/arrow-datafusion/issues/2888
114-
let sql = "SELECT * FROM test WHERE EXISTS (\
115-
SELECT * FROM test t2 WHERE test.col_int32 = t2.col_int32 \
116-
AND test.col_uint32 != t2.col_uint32)";
115+
let sql = "SELECT col_utf8 FROM test WHERE EXISTS (\
116+
SELECT col_utf8 FROM test t2 WHERE test.col_int32 = t2.col_int32 \
117+
AND test.col_uint32 != t2.col_uint32)";
117118
let plan = test_sql(sql)?;
118-
let expected = r#"Projection: test.col_int32, test.col_uint32, test.col_utf8, test.col_date32, test.col_date64
119-
Semi Join: test.col_int32 = t2.col_int32 Filter: test.col_uint32 != t2.col_uint32
120-
TableScan: test projection=[col_int32, col_uint32, col_utf8, col_date32, col_date64]
121-
SubqueryAlias: t2
122-
TableScan: test projection=[col_int32, col_uint32, col_utf8, col_date32, col_date64]"#;
119+
let expected = "Projection: test.col_utf8\
120+
\n Semi Join: test.col_int32 = t2.col_int32 Filter: test.col_uint32 != t2.col_uint32\
121+
\n TableScan: test projection=[col_int32, col_uint32, col_utf8]\
122+
\n SubqueryAlias: t2\
123+
\n TableScan: test projection=[col_int32, col_uint32, col_utf8]";
123124
assert_eq!(expected, format!("{:?}", plan));
124125
Ok(())
125126
}
126127

127128
#[test]
128129
fn anti_join_with_join_filter() -> Result<()> {
129130
// regression test for https://github.com/apache/arrow-datafusion/issues/2888
130-
let sql = "SELECT * FROM test WHERE NOT EXISTS (\
131-
SELECT * FROM test t2 WHERE test.col_int32 = t2.col_int32 \
132-
AND test.col_uint32 != t2.col_uint32)";
131+
let sql = "SELECT col_utf8 FROM test WHERE NOT EXISTS (\
132+
SELECT col_utf8 FROM test t2 WHERE test.col_int32 = t2.col_int32 \
133+
AND test.col_uint32 != t2.col_uint32)";
133134
let plan = test_sql(sql)?;
134-
let expected = r#"Projection: test.col_int32, test.col_uint32, test.col_utf8, test.col_date32, test.col_date64
135-
Anti Join: test.col_int32 = t2.col_int32 Filter: test.col_uint32 != t2.col_uint32
136-
TableScan: test projection=[col_int32, col_uint32, col_utf8, col_date32, col_date64]
137-
SubqueryAlias: t2
138-
TableScan: test projection=[col_int32, col_uint32, col_utf8, col_date32, col_date64]"#;
135+
let expected = "Projection: test.col_utf8\
136+
\n Anti Join: test.col_int32 = t2.col_int32 Filter: test.col_uint32 != t2.col_uint32\
137+
\n TableScan: test projection=[col_int32, col_uint32, col_utf8]\
138+
\n SubqueryAlias: t2\
139+
\n TableScan: test projection=[col_int32, col_uint32, col_utf8]";
139140
assert_eq!(expected, format!("{:?}", plan));
140141
Ok(())
141142
}
142143

143144
#[test]
144145
fn where_exists_distinct() -> Result<()> {
145146
// regression test for https://github.com/apache/arrow-datafusion/issues/3724
146-
let sql = "SELECT * FROM test WHERE EXISTS (\
147-
SELECT DISTINCT col_int32 FROM test t2 WHERE test.col_int32 = t2.col_int32)";
147+
let sql = "SELECT col_int32 FROM test WHERE EXISTS (\
148+
SELECT DISTINCT col_int32 FROM test t2 WHERE test.col_int32 = t2.col_int32)";
148149
let plan = test_sql(sql)?;
149-
let expected = r#"Projection: test.col_int32, test.col_uint32, test.col_utf8, test.col_date32, test.col_date64
150-
Semi Join: test.col_int32 = t2.col_int32
151-
TableScan: test projection=[col_int32, col_uint32, col_utf8, col_date32, col_date64]
152-
SubqueryAlias: t2
153-
TableScan: test projection=[col_int32, col_uint32, col_utf8, col_date32, col_date64]"#;
150+
let expected = "Projection: test.col_int32\
151+
\n Semi Join: test.col_int32 = t2.col_int32\
152+
\n TableScan: test projection=[col_int32]\
153+
\n SubqueryAlias: t2\
154+
\n TableScan: test projection=[col_int32]";
154155
assert_eq!(expected, format!("{:?}", plan));
155156
Ok(())
156157
}
@@ -225,6 +226,38 @@ fn concat_ws_literals() -> Result<()> {
225226
Ok(())
226227
}
227228

229+
/// Optimizer test: a `now() - interval '1 hour'` predicate on the
/// timezone-less timestamp column `col_ts_nano_none`.
///
/// The expected plan compares the column directly against a single folded
/// constant. The expectation names the same column the query filters on and
/// uses a literal with no timezone (`None`), matching the column's declared
/// type so that no cast is needed on either side.
///
/// Marked `#[ignore]`; see the linked issue below.
#[test]
#[ignore]
// https://github.com/apache/arrow-datafusion/issues/3938
fn timestamp_nano_ts_none_predicates() -> Result<()> {
    let sql = "SELECT col_int32\n\
               FROM test\n\
               WHERE col_ts_nano_none < (now() - interval '1 hour')";
    let plan = test_sql(sql)?;
    // a scan should have the now()... predicate folded to a single
    // constant and compared to the column without a cast so it can be
    // pushed down / pruned
    //
    // NOTE(review): the whitespace after each `\n` is reproduced exactly as it
    // appears in this copy of the file; confirm the indentation against the
    // plan's actual Debug output.
    let expected = "Projection: test.col_int32\n Filter: test.col_ts_nano_none < TimestampNanosecond(1666612093000000000, None)\
                    \n TableScan: test projection=[col_int32, col_ts_nano_none]";
    assert_eq!(expected, format!("{:?}", plan));
    Ok(())
}
245+
246+
/// Optimizer test: a `now() - interval '1 hour'` predicate on the UTC
/// timestamp column `col_ts_nano_utc`.
///
/// Runs the query through `test_sql` and compares the Debug rendering of the
/// resulting plan with the expected text.
#[test]
fn timestamp_nano_ts_utc_predicates() -> Result<()> {
    // The scan should see the now()... expression folded into one constant
    // that is compared to the column without a cast, so the predicate can be
    // pushed down / used for pruning.
    //
    // NOTE(review): the whitespace after each `\n` is reproduced exactly as it
    // appears in this copy of the file; confirm the indentation against the
    // plan's actual Debug output.
    let want = concat!(
        "Projection: test.col_int32",
        "\n Filter: test.col_ts_nano_utc < TimestampNanosecond(1666612093000000000, Some(\"UTC\"))",
        "\n TableScan: test projection=[col_int32, col_ts_nano_utc]",
    );

    let query = "SELECT col_int32\n\
                 FROM test\n\
                 WHERE col_ts_nano_utc < (now() - interval '1 hour')";
    let optimized = test_sql(query)?;

    assert_eq!(want, format!("{:?}", optimized));
    Ok(())
}
260+
228261
fn test_sql(sql: &str) -> Result<LogicalPlan> {
229262
// parse the SQL
230263
let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ...
@@ -236,9 +269,14 @@ fn test_sql(sql: &str) -> Result<LogicalPlan> {
236269
let sql_to_rel = SqlToRel::new(&schema_provider);
237270
let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap();
238271

239-
// optimize the logical plan
240-
let mut config = OptimizerConfig::new().with_skip_failing_rules(false);
272+
// hard code the return value of now()
273+
let now_time =
274+
DateTime::<Utc>::from_utc(NaiveDateTime::from_timestamp(1666615693, 0), Utc);
275+
let mut config = OptimizerConfig::new()
276+
.with_skip_failing_rules(false)
277+
.with_query_execution_start_time(now_time);
241278
let optimizer = Optimizer::new(&config);
279+
// optimize the logical plan
242280
optimizer.optimize(&plan, &mut config, &observe)
243281
}
244282

@@ -258,6 +296,18 @@ impl ContextProvider for MySchemaProvider {
258296
Field::new("col_utf8", DataType::Utf8, true),
259297
Field::new("col_date32", DataType::Date32, true),
260298
Field::new("col_date64", DataType::Date64, true),
299+
// timestamp with no timezone
300+
Field::new(
301+
"col_ts_nano_none",
302+
DataType::Timestamp(TimeUnit::Nanosecond, None),
303+
true,
304+
),
305+
// timestamp with UTC timezone
306+
Field::new(
307+
"col_ts_nano_utc",
308+
DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())),
309+
true,
310+
),
261311
],
262312
HashMap::new(),
263313
);

0 commit comments

Comments
 (0)