Skip to content

Commit 7559c44

Browse files
authored
Improve formatting of binary expressions (#3884)
* Respect operator precedence when writing binary expressions
* update tests
* update tests
* refactor to reduce duplicate code
1 parent 7cba758 commit 7559c44

File tree

8 files changed: 74 additions and 19 deletions

benchmarks/expected-plans/q19.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ Projection: SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS re
33
Projection: lineitem.l_extendedprice, lineitem.l_discount
44
Filter: part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15)
55
Inner Join: lineitem.l_partkey = part.p_partkey
6-
Filter: lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND lineitem.l_shipmode IN ([Utf8("AIR"), Utf8("AIR REG")]) AND lineitem.l_shipinstruct = Utf8("DELIVER IN PERSON")
6+
Filter: (lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)) AND lineitem.l_shipmode IN ([Utf8("AIR"), Utf8("AIR REG")]) AND lineitem.l_shipinstruct = Utf8("DELIVER IN PERSON")
77
TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice, l_discount, l_shipinstruct, l_shipmode]
8-
Filter: part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND part.p_size <= Int32(15) AND part.p_size >= Int32(1)
9-
TableScan: part projection=[p_partkey, p_brand, p_size, p_container]
8+
Filter: (part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1)
9+
TableScan: part projection=[p_partkey, p_brand, p_size, p_container]

datafusion/core/src/physical_optimizer/pruning.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1481,7 +1481,7 @@ mod tests {
14811481
let expr = col("c1")
14821482
.lt(lit(1))
14831483
.and(col("c2").eq(lit(2)).or(col("c2").eq(lit(3))));
1484-
let expected_expr = "c1_min < Int32(1) AND c2_min <= Int32(2) AND Int32(2) <= c2_max OR c2_min <= Int32(3) AND Int32(3) <= c2_max";
1484+
let expected_expr = "c1_min < Int32(1) AND (c2_min <= Int32(2) AND Int32(2) <= c2_max OR c2_min <= Int32(3) AND Int32(3) <= c2_max)";
14851485
let predicate_expr =
14861486
build_predicate_expression(&expr, &schema, &mut required_columns)?;
14871487
assert_eq!(format!("{:?}", predicate_expr), expected_expr);
@@ -1560,7 +1560,9 @@ mod tests {
15601560
list: vec![lit(1), lit(2), lit(3)],
15611561
negated: true,
15621562
};
1563-
let expected_expr = "c1_min != Int32(1) OR Int32(1) != c1_max AND c1_min != Int32(2) OR Int32(2) != c1_max AND c1_min != Int32(3) OR Int32(3) != c1_max";
1563+
let expected_expr = "(c1_min != Int32(1) OR Int32(1) != c1_max) \
1564+
AND (c1_min != Int32(2) OR Int32(2) != c1_max) \
1565+
AND (c1_min != Int32(3) OR Int32(3) != c1_max)";
15641566
let predicate_expr =
15651567
build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?;
15661568
assert_eq!(format!("{:?}", predicate_expr), expected_expr);
@@ -1632,7 +1634,10 @@ mod tests {
16321634
],
16331635
negated: true,
16341636
};
1635-
let expected_expr = "CAST(c1_min AS Int64) != Int64(1) OR Int64(1) != CAST(c1_max AS Int64) AND CAST(c1_min AS Int64) != Int64(2) OR Int64(2) != CAST(c1_max AS Int64) AND CAST(c1_min AS Int64) != Int64(3) OR Int64(3) != CAST(c1_max AS Int64)";
1637+
let expected_expr =
1638+
"(CAST(c1_min AS Int64) != Int64(1) OR Int64(1) != CAST(c1_max AS Int64)) \
1639+
AND (CAST(c1_min AS Int64) != Int64(2) OR Int64(2) != CAST(c1_max AS Int64)) \
1640+
AND (CAST(c1_min AS Int64) != Int64(3) OR Int64(3) != CAST(c1_max AS Int64))";
16361641
let predicate_expr =
16371642
build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?;
16381643
assert_eq!(format!("{:?}", predicate_expr), expected_expr);

datafusion/expr/src/expr.rs

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ use datafusion_common::Result;
3131
use datafusion_common::{plan_err, Column};
3232
use datafusion_common::{DataFusionError, ScalarValue};
3333
use std::fmt;
34-
use std::fmt::Write;
34+
use std::fmt::{Display, Formatter, Write};
3535
use std::hash::{BuildHasher, Hash, Hasher};
3636
use std::ops::Not;
3737
use std::sync::Arc;
@@ -260,6 +260,58 @@ impl BinaryExpr {
260260
pub fn new(left: Box<Expr>, op: Operator, right: Box<Expr>) -> Self {
261261
Self { left, op, right }
262262
}
263+
264+
/// Get the operator precedence
265+
/// use https://www.postgresql.org/docs/7.0/operators.htm#AEN2026 as a reference
266+
pub fn precedence(&self) -> u8 {
267+
match self.op {
268+
Operator::Or => 5,
269+
Operator::And => 10,
270+
Operator::Like | Operator::NotLike => 19,
271+
Operator::NotEq
272+
| Operator::Eq
273+
| Operator::Lt
274+
| Operator::LtEq
275+
| Operator::Gt
276+
| Operator::GtEq => 20,
277+
Operator::Plus | Operator::Minus => 30,
278+
Operator::Multiply | Operator::Divide | Operator::Modulo => 40,
279+
_ => 0,
280+
}
281+
}
282+
}
283+
284+
impl Display for BinaryExpr {
285+
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
286+
// Put parentheses around child binary expressions so that we can see the difference
287+
// between `(a OR b) AND c` and `a OR (b AND c)`. We only insert parentheses when needed,
288+
// based on operator precedence. For example, `(a AND b) OR c` and `a AND b OR c` are
289+
// equivalent and the parentheses are not necessary.
290+
291+
fn write_child(
292+
f: &mut Formatter<'_>,
293+
expr: &Expr,
294+
precedence: u8,
295+
) -> fmt::Result {
296+
match expr {
297+
Expr::BinaryExpr(child) => {
298+
let p = child.precedence();
299+
if p == 0 || p < precedence {
300+
write!(f, "({})", child)?;
301+
} else {
302+
write!(f, "{}", child)?;
303+
}
304+
}
305+
_ => write!(f, "{}", expr)?,
306+
}
307+
Ok(())
308+
}
309+
310+
let precedence = self.precedence();
311+
write_child(f, self.left.as_ref(), precedence)?;
312+
write!(f, " {} ", self.op)?;
313+
write_child(f, self.right.as_ref(), precedence)
314+
}
263315
}
264316

265317
/// CASE expression
@@ -728,9 +780,7 @@ impl fmt::Debug for Expr {
728780
negated: false,
729781
} => write!(f, "{:?} IN ({:?})", expr, subquery),
730782
Expr::ScalarSubquery(subquery) => write!(f, "({:?})", subquery),
731-
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
732-
write!(f, "{:?} {} {:?}", left, op, right)
733-
}
783+
Expr::BinaryExpr(expr) => write!(f, "{}", expr),
734784
Expr::Sort {
735785
expr,
736786
asc,

datafusion/optimizer/src/common_subexpr_eliminate.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -619,7 +619,7 @@ mod test {
619619
)?;
620620

621621
let expected = vec![
622-
(9, "SUM(a + Int32(1)) - AVG(c) * Int32(2)Int32(2)SUM(a + Int32(1)) - AVG(c)AVG(c)cSUM(a + Int32(1))a + Int32(1)Int32(1)a"),
622+
(9, "(SUM(a + Int32(1)) - AVG(c)) * Int32(2)Int32(2)SUM(a + Int32(1)) - AVG(c)AVG(c)cSUM(a + Int32(1))a + Int32(1)Int32(1)a"),
623623
(7, "SUM(a + Int32(1)) - AVG(c)AVG(c)cSUM(a + Int32(1))a + Int32(1)Int32(1)a"),
624624
(4, "SUM(a + Int32(1))a + Int32(1)Int32(1)a"),
625625
(3, "a + Int32(1)Int32(1)a"),
@@ -671,8 +671,8 @@ mod test {
671671
)?
672672
.build()?;
673673

674-
let expected = "Aggregate: groupBy=[[]], aggr=[[SUM(test.a * Int32(1) - test.bInt32(1) - test.btest.bInt32(1)test.a AS test.a * Int32(1) - test.b), SUM(test.a * Int32(1) - test.bInt32(1) - test.btest.bInt32(1)test.a AS test.a * Int32(1) - test.b * Int32(1) + test.c)]]\
675-
\n Projection: test.a * Int32(1) - test.b AS test.a * Int32(1) - test.bInt32(1) - test.btest.bInt32(1)test.a, test.a, test.b, test.c\
674+
let expected = "Aggregate: groupBy=[[]], aggr=[[SUM(test.a * (Int32(1) - test.b)Int32(1) - test.btest.bInt32(1)test.a AS test.a * Int32(1) - test.b), SUM(test.a * (Int32(1) - test.b)Int32(1) - test.btest.bInt32(1)test.a AS test.a * Int32(1) - test.b * (Int32(1) + test.c))]]\
675+
\n Projection: test.a * (Int32(1) - test.b) AS test.a * (Int32(1) - test.b)Int32(1) - test.btest.bInt32(1)test.a, test.a, test.b, test.c\
676676
\n TableScan: test";
677677

678678
assert_optimized_plan_eq(expected, &plan);

datafusion/optimizer/src/filter_push_down.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1044,7 +1044,7 @@ mod tests {
10441044
let expected = "\
10451045
Projection: b * Int32(3) AS a, test.c\
10461046
\n Projection: test.a * Int32(2) + test.c AS b, test.c\
1047-
\n Filter: test.a * Int32(2) + test.c * Int32(3) = Int64(1)\
1047+
\n Filter: (test.a * Int32(2) + test.c) * Int32(3) = Int64(1)\
10481048
\n TableScan: test";
10491049
assert_optimized_plan_eq(&plan, expected);
10501050
Ok(())

datafusion/optimizer/src/reduce_cross_join.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -848,7 +848,7 @@ mod tests {
848848
.build()?;
849849

850850
let expected = vec![
851-
"Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) AND t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
851+
"Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
852852
" Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
853853
" Filter: t2.c < UInt32(15) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
854854
" Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
@@ -936,7 +936,7 @@ mod tests {
936936
.build()?;
937937

938938
let expected = vec![
939-
"Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) AND t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b AND t2.c < UInt32(15) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
939+
"Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) AND t2.c < UInt32(15) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
940940
" Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
941941
" Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
942942
" TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",

datafusion/optimizer/src/subquery_filter_to_join.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ mod tests {
352352
.build()?;
353353

354354
let expected = "Projection: test.b [b:UInt32]\
355-
\n Filter: test.a = UInt32(1) OR test.b IN (<subquery>) AND test.c IN (<subquery>) [a:UInt32, b:UInt32, c:UInt32]\
355+
\n Filter: (test.a = UInt32(1) OR test.b IN (<subquery>)) AND test.c IN (<subquery>) [a:UInt32, b:UInt32, c:UInt32]\
356356
\n Subquery: [c:UInt32]\
357357
\n Projection: sq1.c [c:UInt32]\
358358
\n TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\

datafusion/sql/src/planner.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3414,7 +3414,7 @@ mod tests {
34143414
#[test]
34153415
fn select_binary_expr_nested() {
34163416
let sql = "SELECT (age + salary)/2 from person";
3417-
let expected = "Projection: person.age + person.salary / Int64(2)\
3417+
let expected = "Projection: (person.age + person.salary) / Int64(2)\
34183418
\n TableScan: person";
34193419
quick_test(sql, expected);
34203420
}
@@ -3849,7 +3849,7 @@ mod tests {
38493849
fn select_where_nullif_division() {
38503850
let sql = "SELECT c3/(c4+c5) \
38513851
FROM aggregate_test_100 WHERE c3/nullif(c4+c5, 0) > 0.1";
3852-
let expected = "Projection: aggregate_test_100.c3 / aggregate_test_100.c4 + aggregate_test_100.c5\
3852+
let expected = "Projection: aggregate_test_100.c3 / (aggregate_test_100.c4 + aggregate_test_100.c5)\
38533853
\n Filter: aggregate_test_100.c3 / nullif(aggregate_test_100.c4 + aggregate_test_100.c5, Int64(0)) > Float64(0.1)\
38543854
\n TableScan: aggregate_test_100";
38553855
quick_test(sql, expected);

0 commit comments

Comments
 (0)