Skip to content

Commit 4e1f839

Browse files
authored
Introduce TypePlanner for customizing type planning (#13294)
* introduce `plan_data_type` for ExprPlanner * implement TypePlanner trait instead of extending ExprPlanner * enhance the document
1 parent cc96026 commit 4e1f839

File tree

6 files changed

+211
-10
lines changed

6 files changed

+211
-10
lines changed

datafusion/core/src/execution/context/mod.rs

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1788,22 +1788,24 @@ impl<'n, 'a> TreeNodeVisitor<'n> for BadPlanVisitor<'a> {
17881788

17891789
#[cfg(test)]
17901790
mod tests {
1791-
use std::env;
1792-
use std::path::PathBuf;
1793-
17941791
use super::{super::options::CsvReadOptions, *};
17951792
use crate::assert_batches_eq;
17961793
use crate::execution::memory_pool::MemoryConsumer;
17971794
use crate::execution::runtime_env::RuntimeEnvBuilder;
17981795
use crate::test;
17991796
use crate::test_util::{plan_and_collect, populate_csv_partitions};
1797+
use arrow_schema::{DataType, TimeUnit};
1798+
use std::env;
1799+
use std::path::PathBuf;
18001800

18011801
use datafusion_common_runtime::SpawnedTask;
18021802

18031803
use crate::catalog::SchemaProvider;
18041804
use crate::execution::session_state::SessionStateBuilder;
18051805
use crate::physical_planner::PhysicalPlanner;
18061806
use async_trait::async_trait;
1807+
use datafusion_expr::planner::TypePlanner;
1808+
use sqlparser::ast;
18071809
use tempfile::TempDir;
18081810

18091811
#[tokio::test]
@@ -2200,6 +2202,29 @@ mod tests {
22002202
Ok(())
22012203
}
22022204

2205+
#[tokio::test]
2206+
async fn custom_type_planner() -> Result<()> {
2207+
let state = SessionStateBuilder::new()
2208+
.with_default_features()
2209+
.with_type_planner(Arc::new(MyTypePlanner {}))
2210+
.build();
2211+
let ctx = SessionContext::new_with_state(state);
2212+
let result = ctx
2213+
.sql("SELECT DATETIME '2021-01-01 00:00:00'")
2214+
.await?
2215+
.collect()
2216+
.await?;
2217+
let expected = [
2218+
"+-----------------------------+",
2219+
"| Utf8(\"2021-01-01 00:00:00\") |",
2220+
"+-----------------------------+",
2221+
"| 2021-01-01T00:00:00 |",
2222+
"+-----------------------------+",
2223+
];
2224+
assert_batches_eq!(expected, &result);
2225+
Ok(())
2226+
}
2227+
22032228
struct MyPhysicalPlanner {}
22042229

22052230
#[async_trait]
@@ -2260,4 +2285,25 @@ mod tests {
22602285

22612286
Ok(ctx)
22622287
}
2288+
2289+
#[derive(Debug)]
2290+
struct MyTypePlanner {}
2291+
2292+
impl TypePlanner for MyTypePlanner {
2293+
fn plan_type(&self, sql_type: &ast::DataType) -> Result<Option<DataType>> {
2294+
match sql_type {
2295+
ast::DataType::Datetime(precision) => {
2296+
let precision = match precision {
2297+
Some(0) => TimeUnit::Second,
2298+
Some(3) => TimeUnit::Millisecond,
2299+
Some(6) => TimeUnit::Microsecond,
2300+
None | Some(9) => TimeUnit::Nanosecond,
2301+
_ => unreachable!(),
2302+
};
2303+
Ok(Some(DataType::Timestamp(precision, None)))
2304+
}
2305+
_ => Ok(None),
2306+
}
2307+
}
2308+
}
22632309
}

datafusion/core/src/execution/session_state.rs

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ use datafusion_execution::runtime_env::RuntimeEnv;
4848
use datafusion_execution::TaskContext;
4949
use datafusion_expr::execution_props::ExecutionProps;
5050
use datafusion_expr::expr_rewriter::FunctionRewrite;
51-
use datafusion_expr::planner::ExprPlanner;
51+
use datafusion_expr::planner::{ExprPlanner, TypePlanner};
5252
use datafusion_expr::registry::{FunctionRegistry, SerializerRegistry};
5353
use datafusion_expr::simplify::SimplifyInfo;
5454
use datafusion_expr::var_provider::{is_system_variables, VarType};
@@ -128,6 +128,8 @@ pub struct SessionState {
128128
analyzer: Analyzer,
129129
/// Provides support for customising the SQL planner, e.g. to add support for custom operators like `->>` or `?`
130130
expr_planners: Vec<Arc<dyn ExprPlanner>>,
131+
/// Provides support for customising the SQL type planning
132+
type_planner: Option<Arc<dyn TypePlanner>>,
131133
/// Responsible for optimizing a logical plan
132134
optimizer: Optimizer,
133135
/// Responsible for optimizing a physical execution plan
@@ -192,6 +194,7 @@ impl Debug for SessionState {
192194
.field("table_factories", &self.table_factories)
193195
.field("function_factory", &self.function_factory)
194196
.field("expr_planners", &self.expr_planners)
197+
.field("type_planner", &self.type_planner)
195198
.field("query_planners", &self.query_planner)
196199
.field("analyzer", &self.analyzer)
197200
.field("optimizer", &self.optimizer)
@@ -955,6 +958,7 @@ pub struct SessionStateBuilder {
955958
session_id: Option<String>,
956959
analyzer: Option<Analyzer>,
957960
expr_planners: Option<Vec<Arc<dyn ExprPlanner>>>,
961+
type_planner: Option<Arc<dyn TypePlanner>>,
958962
optimizer: Option<Optimizer>,
959963
physical_optimizers: Option<PhysicalOptimizer>,
960964
query_planner: Option<Arc<dyn QueryPlanner + Send + Sync>>,
@@ -984,6 +988,7 @@ impl SessionStateBuilder {
984988
session_id: None,
985989
analyzer: None,
986990
expr_planners: None,
991+
type_planner: None,
987992
optimizer: None,
988993
physical_optimizers: None,
989994
query_planner: None,
@@ -1031,6 +1036,7 @@ impl SessionStateBuilder {
10311036
session_id: None,
10321037
analyzer: Some(existing.analyzer),
10331038
expr_planners: Some(existing.expr_planners),
1039+
type_planner: existing.type_planner,
10341040
optimizer: Some(existing.optimizer),
10351041
physical_optimizers: Some(existing.physical_optimizers),
10361042
query_planner: Some(existing.query_planner),
@@ -1125,6 +1131,12 @@ impl SessionStateBuilder {
11251131
self
11261132
}
11271133

1134+
/// Set the [`TypePlanner`] used to customize the behavior of the SQL planner.
1135+
pub fn with_type_planner(mut self, type_planner: Arc<dyn TypePlanner>) -> Self {
1136+
self.type_planner = Some(type_planner);
1137+
self
1138+
}
1139+
11281140
/// Set the [`PhysicalOptimizerRule`]s used to optimize plans.
11291141
pub fn with_physical_optimizer_rules(
11301142
mut self,
@@ -1318,6 +1330,7 @@ impl SessionStateBuilder {
13181330
session_id,
13191331
analyzer,
13201332
expr_planners,
1333+
type_planner,
13211334
optimizer,
13221335
physical_optimizers,
13231336
query_planner,
@@ -1346,6 +1359,7 @@ impl SessionStateBuilder {
13461359
session_id: session_id.unwrap_or(Uuid::new_v4().to_string()),
13471360
analyzer: analyzer.unwrap_or_default(),
13481361
expr_planners: expr_planners.unwrap_or_default(),
1362+
type_planner,
13491363
optimizer: optimizer.unwrap_or_default(),
13501364
physical_optimizers: physical_optimizers.unwrap_or_default(),
13511365
query_planner: query_planner.unwrap_or(Arc::new(DefaultQueryPlanner {})),
@@ -1456,6 +1470,11 @@ impl SessionStateBuilder {
14561470
&mut self.expr_planners
14571471
}
14581472

1473+
/// Returns the current type_planner value
1474+
pub fn type_planner(&mut self) -> &mut Option<Arc<dyn TypePlanner>> {
1475+
&mut self.type_planner
1476+
}
1477+
14591478
/// Returns the current optimizer value
14601479
pub fn optimizer(&mut self) -> &mut Option<Optimizer> {
14611480
&mut self.optimizer
@@ -1578,6 +1597,7 @@ impl Debug for SessionStateBuilder {
15781597
.field("table_factories", &self.table_factories)
15791598
.field("function_factory", &self.function_factory)
15801599
.field("expr_planners", &self.expr_planners)
1600+
.field("type_planner", &self.type_planner)
15811601
.field("query_planners", &self.query_planner)
15821602
.field("analyzer_rules", &self.analyzer_rules)
15831603
.field("analyzer", &self.analyzer)
@@ -1619,6 +1639,14 @@ impl<'a> ContextProvider for SessionContextProvider<'a> {
16191639
&self.state.expr_planners
16201640
}
16211641

1642+
fn get_type_planner(&self) -> Option<Arc<dyn TypePlanner>> {
1643+
if let Some(type_planner) = &self.state.type_planner {
1644+
Some(Arc::clone(type_planner))
1645+
} else {
1646+
None
1647+
}
1648+
}
1649+
16221650
fn get_table_source(
16231651
&self,
16241652
name: TableReference,

datafusion/expr/src/planner.rs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ use datafusion_common::{
2525
config::ConfigOptions, file_options::file_type::FileType, not_impl_err, DFSchema,
2626
Result, TableReference,
2727
};
28+
use sqlparser::ast;
2829

2930
use crate::{AggregateUDF, Expr, GetFieldAccess, ScalarUDF, TableSource, WindowUDF};
3031

@@ -66,6 +67,11 @@ pub trait ContextProvider {
6667
&[]
6768
}
6869

70+
/// Getter for the data type planner
71+
fn get_type_planner(&self) -> Option<Arc<dyn TypePlanner>> {
72+
None
73+
}
74+
6975
/// Getter for a UDF description
7076
fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>>;
7177
/// Getter for a UDAF description
@@ -216,7 +222,7 @@ pub trait ExprPlanner: Debug + Send + Sync {
216222
/// custom expressions.
217223
#[derive(Debug, Clone)]
218224
pub struct RawBinaryExpr {
219-
pub op: sqlparser::ast::BinaryOperator,
225+
pub op: ast::BinaryOperator,
220226
pub left: Expr,
221227
pub right: Expr,
222228
}
@@ -249,3 +255,13 @@ pub enum PlannerResult<T> {
249255
/// The raw expression could not be planned, and is returned unmodified
250256
Original(T),
251257
}
258+
259+
/// This trait allows users to customize the behavior of the data type planning
260+
pub trait TypePlanner: Debug + Send + Sync {
261+
/// Plan SQL type to DataFusion data type
262+
///
263+
/// Returns None if not possible
264+
fn plan_type(&self, _sql_type: &ast::DataType) -> Result<Option<DataType>> {
265+
Ok(None)
266+
}
267+
}

datafusion/sql/src/planner.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
401401
}
402402

403403
pub(crate) fn convert_data_type(&self, sql_type: &SQLDataType) -> Result<DataType> {
404+
// First check if any of the registered type_planner can handle this type
405+
if let Some(type_planner) = self.context_provider.get_type_planner() {
406+
if let Some(data_type) = type_planner.plan_type(sql_type)? {
407+
return Ok(data_type);
408+
}
409+
}
410+
411+
// If no type_planner can handle this type, use the default conversion
404412
match sql_type {
405413
SQLDataType::Array(ArrayElemTypeDef::AngleBracket(inner_sql_type)) => {
406414
// Arrays may be multi-dimensional.

datafusion/sql/tests/common/mod.rs

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,16 @@
1818
use std::any::Any;
1919
#[cfg(test)]
2020
use std::collections::HashMap;
21-
use std::fmt::Display;
21+
use std::fmt::{Debug, Display};
2222
use std::{sync::Arc, vec};
2323

2424
use arrow_schema::*;
2525
use datafusion_common::config::ConfigOptions;
2626
use datafusion_common::file_options::file_type::FileType;
27-
use datafusion_common::{plan_err, GetExt, Result, TableReference};
28-
use datafusion_expr::planner::ExprPlanner;
29-
use datafusion_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF};
27+
use datafusion_common::{plan_err, DFSchema, GetExt, Result, TableReference};
28+
use datafusion_expr::planner::{ExprPlanner, PlannerResult, TypePlanner};
29+
use datafusion_expr::{AggregateUDF, Expr, ScalarUDF, TableSource, WindowUDF};
30+
use datafusion_functions_nested::expr_fn::make_array;
3031
use datafusion_sql::planner::ContextProvider;
3132

3233
struct MockCsvType {}
@@ -54,6 +55,7 @@ pub(crate) struct MockSessionState {
5455
scalar_functions: HashMap<String, Arc<ScalarUDF>>,
5556
aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
5657
expr_planners: Vec<Arc<dyn ExprPlanner>>,
58+
type_planner: Option<Arc<dyn TypePlanner>>,
5759
window_functions: HashMap<String, Arc<WindowUDF>>,
5860
pub config_options: ConfigOptions,
5961
}
@@ -64,6 +66,11 @@ impl MockSessionState {
6466
self
6567
}
6668

69+
pub fn with_type_planner(mut self, type_planner: Arc<dyn TypePlanner>) -> Self {
70+
self.type_planner = Some(type_planner);
71+
self
72+
}
73+
6774
pub fn with_scalar_function(mut self, scalar_function: Arc<ScalarUDF>) -> Self {
6875
self.scalar_functions
6976
.insert(scalar_function.name().to_string(), scalar_function);
@@ -259,6 +266,14 @@ impl ContextProvider for MockContextProvider {
259266
fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
260267
&self.state.expr_planners
261268
}
269+
270+
fn get_type_planner(&self) -> Option<Arc<dyn TypePlanner>> {
271+
if let Some(type_planner) = &self.state.type_planner {
272+
Some(Arc::clone(type_planner))
273+
} else {
274+
None
275+
}
276+
}
262277
}
263278

264279
struct EmptyTable {
@@ -280,3 +295,37 @@ impl TableSource for EmptyTable {
280295
Arc::clone(&self.table_schema)
281296
}
282297
}
298+
299+
#[derive(Debug)]
300+
pub struct CustomTypePlanner {}
301+
302+
impl TypePlanner for CustomTypePlanner {
303+
fn plan_type(&self, sql_type: &sqlparser::ast::DataType) -> Result<Option<DataType>> {
304+
match sql_type {
305+
sqlparser::ast::DataType::Datetime(precision) => {
306+
let precision = match precision {
307+
Some(0) => TimeUnit::Second,
308+
Some(3) => TimeUnit::Millisecond,
309+
Some(6) => TimeUnit::Microsecond,
310+
None | Some(9) => TimeUnit::Nanosecond,
311+
_ => unreachable!(),
312+
};
313+
Ok(Some(DataType::Timestamp(precision, None)))
314+
}
315+
_ => Ok(None),
316+
}
317+
}
318+
}
319+
320+
#[derive(Debug)]
321+
pub struct CustomExprPlanner {}
322+
323+
impl ExprPlanner for CustomExprPlanner {
324+
fn plan_array_literal(
325+
&self,
326+
exprs: Vec<Expr>,
327+
_schema: &DFSchema,
328+
) -> Result<PlannerResult<Vec<Expr>>> {
329+
Ok(PlannerResult::Planned(make_array(exprs)))
330+
}
331+
}

0 commit comments

Comments
 (0)