Skip to content

Commit 827d0e3

Browse files
authored
Add dialect param to use double precision for float64 in Postgres (#11495)
* Add dialect param to use double precision for float64 in Postgres * return ast data type instead of bool * Fix errors in merging * fix
1 parent ebe61ba commit 827d0e3

File tree

2 files changed

+57
-1
lines changed

2 files changed

+57
-1
lines changed

datafusion/sql/src/unparser/dialect.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,18 @@ pub trait Dialect: Send + Sync {
4646
IntervalStyle::PostgresVerbose
4747
}
4848

49+
// Does the dialect use DOUBLE PRECISION to represent Float64 rather than DOUBLE?
50+
// E.g. Postgres uses DOUBLE PRECISION instead of DOUBLE
51+
fn float64_ast_dtype(&self) -> sqlparser::ast::DataType {
52+
sqlparser::ast::DataType::Double
53+
}
54+
4955
// The SQL type to use for Arrow Utf8 unparsing
5056
// Most dialects use VARCHAR, but some, like MySQL, require CHAR
5157
fn utf8_cast_dtype(&self) -> ast::DataType {
5258
ast::DataType::Varchar(None)
5359
}
60+
5461
// The SQL type to use for Arrow LargeUtf8 unparsing
5562
// Most dialects use TEXT, but some, like MySQL, require CHAR
5663
fn large_utf8_cast_dtype(&self) -> ast::DataType {
@@ -98,6 +105,10 @@ impl Dialect for PostgreSqlDialect {
98105
fn interval_style(&self) -> IntervalStyle {
99106
IntervalStyle::PostgresVerbose
100107
}
108+
109+
fn float64_ast_dtype(&self) -> sqlparser::ast::DataType {
110+
sqlparser::ast::DataType::DoublePrecision
111+
}
101112
}
102113

103114
pub struct MySqlDialect {}
@@ -137,6 +148,7 @@ pub struct CustomDialect {
137148
supports_nulls_first_in_sort: bool,
138149
use_timestamp_for_date64: bool,
139150
interval_style: IntervalStyle,
151+
float64_ast_dtype: sqlparser::ast::DataType,
140152
utf8_cast_dtype: ast::DataType,
141153
large_utf8_cast_dtype: ast::DataType,
142154
}
@@ -148,6 +160,7 @@ impl Default for CustomDialect {
148160
supports_nulls_first_in_sort: true,
149161
use_timestamp_for_date64: false,
150162
interval_style: IntervalStyle::SQLStandard,
163+
float64_ast_dtype: sqlparser::ast::DataType::Double,
151164
utf8_cast_dtype: ast::DataType::Varchar(None),
152165
large_utf8_cast_dtype: ast::DataType::Text,
153166
}
@@ -182,6 +195,10 @@ impl Dialect for CustomDialect {
182195
self.interval_style
183196
}
184197

198+
fn float64_ast_dtype(&self) -> sqlparser::ast::DataType {
199+
self.float64_ast_dtype.clone()
200+
}
201+
185202
fn utf8_cast_dtype(&self) -> ast::DataType {
186203
self.utf8_cast_dtype.clone()
187204
}
@@ -210,6 +227,7 @@ pub struct CustomDialectBuilder {
210227
supports_nulls_first_in_sort: bool,
211228
use_timestamp_for_date64: bool,
212229
interval_style: IntervalStyle,
230+
float64_ast_dtype: sqlparser::ast::DataType,
213231
utf8_cast_dtype: ast::DataType,
214232
large_utf8_cast_dtype: ast::DataType,
215233
}
@@ -227,6 +245,7 @@ impl CustomDialectBuilder {
227245
supports_nulls_first_in_sort: true,
228246
use_timestamp_for_date64: false,
229247
interval_style: IntervalStyle::PostgresVerbose,
248+
float64_ast_dtype: sqlparser::ast::DataType::Double,
230249
utf8_cast_dtype: ast::DataType::Varchar(None),
231250
large_utf8_cast_dtype: ast::DataType::Text,
232251
}
@@ -238,6 +257,7 @@ impl CustomDialectBuilder {
238257
supports_nulls_first_in_sort: self.supports_nulls_first_in_sort,
239258
use_timestamp_for_date64: self.use_timestamp_for_date64,
240259
interval_style: self.interval_style,
260+
float64_ast_dtype: self.float64_ast_dtype,
241261
utf8_cast_dtype: self.utf8_cast_dtype,
242262
large_utf8_cast_dtype: self.large_utf8_cast_dtype,
243263
}
@@ -273,6 +293,14 @@ impl CustomDialectBuilder {
273293
self
274294
}
275295

296+
pub fn with_float64_ast_dtype(
297+
mut self,
298+
float64_ast_dtype: sqlparser::ast::DataType,
299+
) -> Self {
300+
self.float64_ast_dtype = float64_ast_dtype;
301+
self
302+
}
303+
276304
pub fn with_utf8_cast_dtype(mut self, utf8_cast_dtype: ast::DataType) -> Self {
277305
self.utf8_cast_dtype = utf8_cast_dtype;
278306
self

datafusion/sql/src/unparser/expr.rs

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1240,7 +1240,7 @@ impl Unparser<'_> {
12401240
not_impl_err!("Unsupported DataType: conversion: {data_type:?}")
12411241
}
12421242
DataType::Float32 => Ok(ast::DataType::Float(None)),
1243-
DataType::Float64 => Ok(ast::DataType::Double),
1243+
DataType::Float64 => Ok(self.dialect.float64_ast_dtype()),
12441244
DataType::Timestamp(_, tz) => {
12451245
let tz_info = match tz {
12461246
Some(_) => TimezoneInfo::WithTimeZone,
@@ -1822,6 +1822,34 @@ mod tests {
18221822
Ok(())
18231823
}
18241824

1825+
#[test]
1826+
fn custom_dialect_float64_ast_dtype() -> Result<()> {
1827+
for (float64_ast_dtype, identifier) in [
1828+
(sqlparser::ast::DataType::Double, "DOUBLE"),
1829+
(
1830+
sqlparser::ast::DataType::DoublePrecision,
1831+
"DOUBLE PRECISION",
1832+
),
1833+
] {
1834+
let dialect = CustomDialectBuilder::new()
1835+
.with_float64_ast_dtype(float64_ast_dtype)
1836+
.build();
1837+
let unparser = Unparser::new(&dialect);
1838+
1839+
let expr = Expr::Cast(Cast {
1840+
expr: Box::new(col("a")),
1841+
data_type: DataType::Float64,
1842+
});
1843+
let ast = unparser.expr_to_sql(&expr)?;
1844+
1845+
let actual = format!("{}", ast);
1846+
1847+
let expected = format!(r#"CAST(a AS {identifier})"#);
1848+
assert_eq!(actual, expected);
1849+
}
1850+
Ok(())
1851+
}
1852+
18251853
#[test]
18261854
fn customer_dialect_support_nulls_first_in_ort() -> Result<()> {
18271855
let tests: Vec<(Expr, &str, bool)> = vec![

0 commit comments

Comments
 (0)