Skip to content

Commit 5fa90df

Browse files
authored
fix decimal precision issue in simplify expression optimize rule (#15588)
* fix decimal precision * return expr if fail to find the casted type * move fn out * comment * fmt * fix * fix name and doc
1 parent 3b2df6f commit 5fa90df

File tree

2 files changed

+143
-20
lines changed

2 files changed

+143
-20
lines changed

datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs

Lines changed: 111 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ use datafusion_common::{
3333
};
3434
use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue};
3535
use datafusion_expr::{
36-
and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Volatility,
37-
WindowFunctionDefinition,
36+
and, binary::BinaryTypeCoercer, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like,
37+
Operator, Volatility, WindowFunctionDefinition,
3838
};
3939
use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval};
4040
use datafusion_expr::{
@@ -976,30 +976,39 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
976976
// Rules for Multiply
977977
//
978978

979-
// A * 1 --> A
979+
// A * 1 --> A (with type coercion if needed)
980980
Expr::BinaryExpr(BinaryExpr {
981981
left,
982982
op: Multiply,
983983
right,
984-
}) if is_one(&right) => Transformed::yes(*left),
985-
// 1 * A --> A
984+
}) if is_one(&right) => {
985+
simplify_right_is_one_case(info, left, &Multiply, &right)?
986+
}
987+
// A * null --> null
986988
Expr::BinaryExpr(BinaryExpr {
987989
left,
988990
op: Multiply,
989991
right,
990-
}) if is_one(&left) => Transformed::yes(*right),
991-
// A * null --> null
992+
}) if is_null(&right) => {
993+
simplify_right_is_null_case(info, &left, &Multiply, right)?
994+
}
995+
// 1 * A --> A
992996
Expr::BinaryExpr(BinaryExpr {
993-
left: _,
997+
left,
994998
op: Multiply,
995999
right,
996-
}) if is_null(&right) => Transformed::yes(*right),
1000+
}) if is_one(&left) => {
1001+
// 1 * A is equivalent to A * 1
1002+
simplify_right_is_one_case(info, right, &Multiply, &left)?
1003+
}
9971004
// null * A --> null
9981005
Expr::BinaryExpr(BinaryExpr {
9991006
left,
10001007
op: Multiply,
1001-
right: _,
1002-
}) if is_null(&left) => Transformed::yes(*left),
1008+
right,
1009+
}) if is_null(&left) => {
1010+
simplify_right_is_null_case(info, &right, &Multiply, left)?
1011+
}
10031012

10041013
// A * 0 --> 0 (if A is not null and not floating, since NAN * 0 -> NAN)
10051014
Expr::BinaryExpr(BinaryExpr {
@@ -1033,19 +1042,23 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
10331042
left,
10341043
op: Divide,
10351044
right,
1036-
}) if is_one(&right) => Transformed::yes(*left),
1037-
// null / A --> null
1045+
}) if is_one(&right) => {
1046+
simplify_right_is_one_case(info, left, &Divide, &right)?
1047+
}
1048+
// A / null --> null
10381049
Expr::BinaryExpr(BinaryExpr {
10391050
left,
10401051
op: Divide,
1041-
right: _,
1042-
}) if is_null(&left) => Transformed::yes(*left),
1043-
// A / null --> null
1052+
right,
1053+
}) if is_null(&right) => {
1054+
simplify_right_is_null_case(info, &left, &Divide, right)?
1055+
}
1056+
// null / A --> null
10441057
Expr::BinaryExpr(BinaryExpr {
1045-
left: _,
1058+
left,
10461059
op: Divide,
10471060
right,
1048-
}) if is_null(&right) => Transformed::yes(*right),
1061+
}) if is_null(&left) => simplify_null_div_other_case(info, left, &right)?,
10491062

10501063
//
10511064
// Rules for Modulo
@@ -1997,6 +2010,84 @@ fn is_exactly_true(expr: Expr, info: &impl SimplifyInfo) -> Result<Expr> {
19972010
}
19982011
}
19992012

2013+
// A * 1 -> A
2014+
// A / 1 -> A
2015+
//
2016+
// Move this function body out of the large match branch avoid stack overflow
2017+
fn simplify_right_is_one_case<S: SimplifyInfo>(
2018+
info: &S,
2019+
left: Box<Expr>,
2020+
op: &Operator,
2021+
right: &Expr,
2022+
) -> Result<Transformed<Expr>> {
2023+
// Check if resulting type would be different due to coercion
2024+
let left_type = info.get_data_type(&left)?;
2025+
let right_type = info.get_data_type(right)?;
2026+
match BinaryTypeCoercer::new(&left_type, op, &right_type).get_result_type() {
2027+
Ok(result_type) => {
2028+
// Only cast if the types differ
2029+
if left_type != result_type {
2030+
Ok(Transformed::yes(Expr::Cast(Cast::new(left, result_type))))
2031+
} else {
2032+
Ok(Transformed::yes(*left))
2033+
}
2034+
}
2035+
Err(_) => Ok(Transformed::yes(*left)),
2036+
}
2037+
}
2038+
2039+
// A * null -> null
2040+
// A / null -> null
2041+
//
2042+
// Move this function body out of the large match branch avoid stack overflow
2043+
fn simplify_right_is_null_case<S: SimplifyInfo>(
2044+
info: &S,
2045+
left: &Expr,
2046+
op: &Operator,
2047+
right: Box<Expr>,
2048+
) -> Result<Transformed<Expr>> {
2049+
// Check if resulting type would be different due to coercion
2050+
let left_type = info.get_data_type(left)?;
2051+
let right_type = info.get_data_type(&right)?;
2052+
match BinaryTypeCoercer::new(&left_type, op, &right_type).get_result_type() {
2053+
Ok(result_type) => {
2054+
// Only cast if the types differ
2055+
if right_type != result_type {
2056+
Ok(Transformed::yes(Expr::Cast(Cast::new(right, result_type))))
2057+
} else {
2058+
Ok(Transformed::yes(*right))
2059+
}
2060+
}
2061+
Err(_) => Ok(Transformed::yes(*right)),
2062+
}
2063+
}
2064+
2065+
// null / A --> null
2066+
//
2067+
// Move this function body out of the large match branch avoid stack overflow
2068+
fn simplify_null_div_other_case<S: SimplifyInfo>(
2069+
info: &S,
2070+
left: Box<Expr>,
2071+
right: &Expr,
2072+
) -> Result<Transformed<Expr>> {
2073+
// Check if resulting type would be different due to coercion
2074+
let left_type = info.get_data_type(&left)?;
2075+
let right_type = info.get_data_type(right)?;
2076+
match BinaryTypeCoercer::new(&left_type, &Operator::Divide, &right_type)
2077+
.get_result_type()
2078+
{
2079+
Ok(result_type) => {
2080+
// Only cast if the types differ
2081+
if left_type != result_type {
2082+
Ok(Transformed::yes(Expr::Cast(Cast::new(left, result_type))))
2083+
} else {
2084+
Ok(Transformed::yes(*left))
2085+
}
2086+
}
2087+
Err(_) => Ok(Transformed::yes(*left)),
2088+
}
2089+
}
2090+
20002091
#[cfg(test)]
20012092
mod tests {
20022093
use crate::simplify_expressions::SimplifyContext;
@@ -2316,12 +2407,12 @@ mod tests {
23162407
// A / null --> null
23172408
let null = lit(ScalarValue::Null);
23182409
{
2319-
let expr = col("c") / null.clone();
2410+
let expr = col("c1") / null.clone();
23202411
assert_eq!(simplify(expr), null);
23212412
}
23222413
// null / A --> null
23232414
{
2324-
let expr = null.clone() / col("c");
2415+
let expr = null.clone() / col("c1");
23252416
assert_eq!(simplify(expr), null);
23262417
}
23272418
}

datafusion/sqllogictest/test_files/simplify_expr.slt

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,35 @@ select b from t where b !~ '.*'
6565

6666
statement ok
6767
drop table t;
68+
69+
# test decimal precision
70+
query B
71+
SELECT a * 1.000::DECIMAL(4,3) > 1.2::decimal(2,1) FROM VALUES (1) AS t(a);
72+
----
73+
false
74+
75+
query B
76+
SELECT 1.000::DECIMAL(4,3) * a > 1.2::decimal(2,1) FROM VALUES (1) AS t(a);
77+
----
78+
false
79+
80+
query B
81+
SELECT NULL::DECIMAL(4,3) * a > 1.2::decimal(2,1) FROM VALUES (1) AS t(a);
82+
----
83+
NULL
84+
85+
query B
86+
SELECT a * NULL::DECIMAL(4,3) > 1.2::decimal(2,1) FROM VALUES (1) AS t(a);
87+
----
88+
NULL
89+
90+
query B
91+
SELECT a / 1.000::DECIMAL(4,3) > 1.2::decimal(2,1) FROM VALUES (1) AS t(a);
92+
----
93+
false
94+
95+
query B
96+
SELECT a / NULL::DECIMAL(4,3) > 1.2::decimal(2,1) FROM VALUES (1) AS t(a);
97+
----
98+
NULL
99+

0 commit comments

Comments
 (0)