Skip to content

Commit 9b492c6

Browse files
authored
Improve round scalar function unparsing for Postgres (#12744)
* Postgres: enforce required `NUMERIC` type for `round` scalar function (#34) Includes initial support for dialects to override scalar functions unparsing * Document scalar_function_to_sql_overrides fn
1 parent ecb0044 commit 9b492c6

File tree

3 files changed

+273
-126
lines changed

3 files changed

+273
-126
lines changed

datafusion/sql/src/unparser/dialect.rs

Lines changed: 118 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,17 @@
1818
use std::sync::Arc;
1919

2020
use arrow_schema::TimeUnit;
21+
use datafusion_expr::Expr;
2122
use regex::Regex;
2223
use sqlparser::{
23-
ast::{self, Ident, ObjectName, TimezoneInfo},
24+
ast::{self, Function, Ident, ObjectName, TimezoneInfo},
2425
keywords::ALL_KEYWORDS,
2526
};
2627

28+
use datafusion_common::Result;
29+
30+
use super::{utils::date_part_to_sql, Unparser};
31+
2732
/// `Dialect` to use for Unparsing
2833
///
2934
/// The default dialect tries to avoid quoting identifiers unless necessary (e.g. `a` instead of `"a"`)
@@ -108,6 +113,18 @@ pub trait Dialect: Send + Sync {
108113
fn supports_column_alias_in_table_alias(&self) -> bool {
109114
true
110115
}
116+
117+
/// Allows the dialect to override scalar function unparsing if the dialect has specific rules.
118+
/// Returns None if the default unparsing should be used, or Some(ast::Expr) if there is
119+
/// a custom implementation for the function.
120+
fn scalar_function_to_sql_overrides(
121+
&self,
122+
_unparser: &Unparser,
123+
_func_name: &str,
124+
_args: &[Expr],
125+
) -> Result<Option<ast::Expr>> {
126+
Ok(None)
127+
}
111128
}
112129

113130
/// `IntervalStyle` to use for unparsing
@@ -171,6 +188,67 @@ impl Dialect for PostgreSqlDialect {
171188
fn float64_ast_dtype(&self) -> sqlparser::ast::DataType {
172189
sqlparser::ast::DataType::DoublePrecision
173190
}
191+
192+
fn scalar_function_to_sql_overrides(
193+
&self,
194+
unparser: &Unparser,
195+
func_name: &str,
196+
args: &[Expr],
197+
) -> Result<Option<ast::Expr>> {
198+
if func_name == "round" {
199+
return Ok(Some(
200+
self.round_to_sql_enforce_numeric(unparser, func_name, args)?,
201+
));
202+
}
203+
204+
Ok(None)
205+
}
206+
}
207+
208+
impl PostgreSqlDialect {
209+
fn round_to_sql_enforce_numeric(
210+
&self,
211+
unparser: &Unparser,
212+
func_name: &str,
213+
args: &[Expr],
214+
) -> Result<ast::Expr> {
215+
let mut args = unparser.function_args_to_sql(args)?;
216+
217+
// Enforce the first argument to be Numeric
218+
if let Some(ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(expr))) =
219+
args.first_mut()
220+
{
221+
if let ast::Expr::Cast { data_type, .. } = expr {
222+
// Don't create an additional cast wrapper if we can update the existing one
223+
*data_type = ast::DataType::Numeric(ast::ExactNumberInfo::None);
224+
} else {
225+
// Wrap the expression in a new cast
226+
*expr = ast::Expr::Cast {
227+
kind: ast::CastKind::Cast,
228+
expr: Box::new(expr.clone()),
229+
data_type: ast::DataType::Numeric(ast::ExactNumberInfo::None),
230+
format: None,
231+
};
232+
}
233+
}
234+
235+
Ok(ast::Expr::Function(Function {
236+
name: ast::ObjectName(vec![Ident {
237+
value: func_name.to_string(),
238+
quote_style: None,
239+
}]),
240+
args: ast::FunctionArguments::List(ast::FunctionArgumentList {
241+
duplicate_treatment: None,
242+
args,
243+
clauses: vec![],
244+
}),
245+
filter: None,
246+
null_treatment: None,
247+
over: None,
248+
within_group: vec![],
249+
parameters: ast::FunctionArguments::None,
250+
}))
251+
}
174252
}
175253

176254
pub struct MySqlDialect {}
@@ -211,6 +289,19 @@ impl Dialect for MySqlDialect {
211289
) -> ast::DataType {
212290
ast::DataType::Datetime(None)
213291
}
292+
293+
fn scalar_function_to_sql_overrides(
294+
&self,
295+
unparser: &Unparser,
296+
func_name: &str,
297+
args: &[Expr],
298+
) -> Result<Option<ast::Expr>> {
299+
if func_name == "date_part" {
300+
return date_part_to_sql(unparser, self.date_field_extract_style(), args);
301+
}
302+
303+
Ok(None)
304+
}
214305
}
215306

216307
pub struct SqliteDialect {}
@@ -231,6 +322,19 @@ impl Dialect for SqliteDialect {
231322
fn supports_column_alias_in_table_alias(&self) -> bool {
232323
false
233324
}
325+
326+
fn scalar_function_to_sql_overrides(
327+
&self,
328+
unparser: &Unparser,
329+
func_name: &str,
330+
args: &[Expr],
331+
) -> Result<Option<ast::Expr>> {
332+
if func_name == "date_part" {
333+
return date_part_to_sql(unparser, self.date_field_extract_style(), args);
334+
}
335+
336+
Ok(None)
337+
}
234338
}
235339

236340
pub struct CustomDialect {
@@ -339,6 +443,19 @@ impl Dialect for CustomDialect {
339443
fn supports_column_alias_in_table_alias(&self) -> bool {
340444
self.supports_column_alias_in_table_alias
341445
}
446+
447+
fn scalar_function_to_sql_overrides(
448+
&self,
449+
unparser: &Unparser,
450+
func_name: &str,
451+
args: &[Expr],
452+
) -> Result<Option<ast::Expr>> {
453+
if func_name == "date_part" {
454+
return date_part_to_sql(unparser, self.date_field_extract_style(), args);
455+
}
456+
457+
Ok(None)
458+
}
342459
}
343460

344461
/// `CustomDialectBuilder` to build `CustomDialect` using builder pattern

0 commit comments

Comments
 (0)