18
18
use std:: sync:: Arc ;
19
19
20
20
use arrow_schema:: TimeUnit ;
21
+ use datafusion_expr:: Expr ;
21
22
use regex:: Regex ;
22
23
use sqlparser:: {
23
- ast:: { self , Ident , ObjectName , TimezoneInfo } ,
24
+ ast:: { self , Function , Ident , ObjectName , TimezoneInfo } ,
24
25
keywords:: ALL_KEYWORDS ,
25
26
} ;
26
27
28
+ use datafusion_common:: Result ;
29
+
30
+ use super :: { utils:: date_part_to_sql, Unparser } ;
31
+
27
32
/// `Dialect` to use for Unparsing
28
33
///
29
34
/// The default dialect tries to avoid quoting identifiers unless necessary (e.g. `a` instead of `"a"`)
@@ -108,6 +113,18 @@ pub trait Dialect: Send + Sync {
108
113
fn supports_column_alias_in_table_alias ( & self ) -> bool {
109
114
true
110
115
}
116
+
117
+ /// Allows the dialect to override scalar function unparsing if the dialect has specific rules.
118
+ /// Returns None if the default unparsing should be used, or Some(ast::Expr) if there is
119
+ /// a custom implementation for the function.
120
+ fn scalar_function_to_sql_overrides (
121
+ & self ,
122
+ _unparser : & Unparser ,
123
+ _func_name : & str ,
124
+ _args : & [ Expr ] ,
125
+ ) -> Result < Option < ast:: Expr > > {
126
+ Ok ( None )
127
+ }
111
128
}
112
129
113
130
/// `IntervalStyle` to use for unparsing
@@ -171,6 +188,67 @@ impl Dialect for PostgreSqlDialect {
171
188
fn float64_ast_dtype ( & self ) -> sqlparser:: ast:: DataType {
172
189
sqlparser:: ast:: DataType :: DoublePrecision
173
190
}
191
+
192
+ fn scalar_function_to_sql_overrides (
193
+ & self ,
194
+ unparser : & Unparser ,
195
+ func_name : & str ,
196
+ args : & [ Expr ] ,
197
+ ) -> Result < Option < ast:: Expr > > {
198
+ if func_name == "round" {
199
+ return Ok ( Some (
200
+ self . round_to_sql_enforce_numeric ( unparser, func_name, args) ?,
201
+ ) ) ;
202
+ }
203
+
204
+ Ok ( None )
205
+ }
206
+ }
207
+
208
+ impl PostgreSqlDialect {
209
+ fn round_to_sql_enforce_numeric (
210
+ & self ,
211
+ unparser : & Unparser ,
212
+ func_name : & str ,
213
+ args : & [ Expr ] ,
214
+ ) -> Result < ast:: Expr > {
215
+ let mut args = unparser. function_args_to_sql ( args) ?;
216
+
217
+ // Enforce the first argument to be Numeric
218
+ if let Some ( ast:: FunctionArg :: Unnamed ( ast:: FunctionArgExpr :: Expr ( expr) ) ) =
219
+ args. first_mut ( )
220
+ {
221
+ if let ast:: Expr :: Cast { data_type, .. } = expr {
222
+ // Don't create an additional cast wrapper if we can update the existing one
223
+ * data_type = ast:: DataType :: Numeric ( ast:: ExactNumberInfo :: None ) ;
224
+ } else {
225
+ // Wrap the expression in a new cast
226
+ * expr = ast:: Expr :: Cast {
227
+ kind : ast:: CastKind :: Cast ,
228
+ expr : Box :: new ( expr. clone ( ) ) ,
229
+ data_type : ast:: DataType :: Numeric ( ast:: ExactNumberInfo :: None ) ,
230
+ format : None ,
231
+ } ;
232
+ }
233
+ }
234
+
235
+ Ok ( ast:: Expr :: Function ( Function {
236
+ name : ast:: ObjectName ( vec ! [ Ident {
237
+ value: func_name. to_string( ) ,
238
+ quote_style: None ,
239
+ } ] ) ,
240
+ args : ast:: FunctionArguments :: List ( ast:: FunctionArgumentList {
241
+ duplicate_treatment : None ,
242
+ args,
243
+ clauses : vec ! [ ] ,
244
+ } ) ,
245
+ filter : None ,
246
+ null_treatment : None ,
247
+ over : None ,
248
+ within_group : vec ! [ ] ,
249
+ parameters : ast:: FunctionArguments :: None ,
250
+ } ) )
251
+ }
174
252
}
175
253
176
254
pub struct MySqlDialect { }
@@ -211,6 +289,19 @@ impl Dialect for MySqlDialect {
211
289
) -> ast:: DataType {
212
290
ast:: DataType :: Datetime ( None )
213
291
}
292
+
293
+ fn scalar_function_to_sql_overrides (
294
+ & self ,
295
+ unparser : & Unparser ,
296
+ func_name : & str ,
297
+ args : & [ Expr ] ,
298
+ ) -> Result < Option < ast:: Expr > > {
299
+ if func_name == "date_part" {
300
+ return date_part_to_sql ( unparser, self . date_field_extract_style ( ) , args) ;
301
+ }
302
+
303
+ Ok ( None )
304
+ }
214
305
}
215
306
216
307
pub struct SqliteDialect { }
@@ -231,6 +322,19 @@ impl Dialect for SqliteDialect {
231
322
fn supports_column_alias_in_table_alias ( & self ) -> bool {
232
323
false
233
324
}
325
+
326
+ fn scalar_function_to_sql_overrides (
327
+ & self ,
328
+ unparser : & Unparser ,
329
+ func_name : & str ,
330
+ args : & [ Expr ] ,
331
+ ) -> Result < Option < ast:: Expr > > {
332
+ if func_name == "date_part" {
333
+ return date_part_to_sql ( unparser, self . date_field_extract_style ( ) , args) ;
334
+ }
335
+
336
+ Ok ( None )
337
+ }
234
338
}
235
339
236
340
pub struct CustomDialect {
@@ -339,6 +443,19 @@ impl Dialect for CustomDialect {
339
443
fn supports_column_alias_in_table_alias ( & self ) -> bool {
340
444
self . supports_column_alias_in_table_alias
341
445
}
446
+
447
+ fn scalar_function_to_sql_overrides (
448
+ & self ,
449
+ unparser : & Unparser ,
450
+ func_name : & str ,
451
+ args : & [ Expr ] ,
452
+ ) -> Result < Option < ast:: Expr > > {
453
+ if func_name == "date_part" {
454
+ return date_part_to_sql ( unparser, self . date_field_extract_style ( ) , args) ;
455
+ }
456
+
457
+ Ok ( None )
458
+ }
342
459
}
343
460
344
461
/// `CustomDialectBuilder` to build `CustomDialect` using builder pattern
0 commit comments