Skip to content

Commit 80f95b2

Browse files
Dandandan and alamb committed
Simplify small InListExpr (apache#4090)
* Simplify small InListExpr
  Simplify small InListExpr
* Tweak
  Tweak
* Update datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
  Co-authored-by: Andrew Lamb <[email protected]>
* Update datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
  Co-authored-by: Andrew Lamb <[email protected]>
* Feedback
* Feedback
* Tweak
* Tweak
  Tweak
* Fmt
* clippy
Co-authored-by: Andrew Lamb <[email protected]>
1 parent 593f00a commit 80f95b2

File tree

5 files changed

+83
-5
lines changed

5 files changed

+83
-5
lines changed

benchmarks/expected-plans/q12.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@ Sort: lineitem.l_shipmode ASC NULLS LAST
22
Projection: lineitem.l_shipmode, SUM(CASE WHEN orders.o_orderpriority = Utf8("1-URGENT") OR orders.o_orderpriority = Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END) AS high_line_count, SUM(CASE WHEN orders.o_orderpriority != Utf8("1-URGENT") AND orders.o_orderpriority != Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END) AS low_line_count
33
Aggregate: groupBy=[[lineitem.l_shipmode]], aggr=[[SUM(CASE WHEN orders.o_orderpriority = Utf8("1-URGENT") OR orders.o_orderpriority = Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END), SUM(CASE WHEN orders.o_orderpriority != Utf8("1-URGENT") AND orders.o_orderpriority != Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END)]]
44
Inner Join: lineitem.l_orderkey = orders.o_orderkey
5-
Filter: lineitem.l_shipmode IN ([Utf8("MAIL"), Utf8("SHIP")]) AND lineitem.l_commitdate < lineitem.l_receiptdate AND lineitem.l_shipdate < lineitem.l_commitdate AND lineitem.l_receiptdate >= Date32("8766") AND lineitem.l_receiptdate < Date32("9131")
5+
Filter: (lineitem.l_shipmode = Utf8("SHIP") OR lineitem.l_shipmode = Utf8("MAIL")) AND lineitem.l_commitdate < lineitem.l_receiptdate AND lineitem.l_shipdate < lineitem.l_commitdate AND lineitem.l_receiptdate >= Date32("8766") AND lineitem.l_receiptdate < Date32("9131")
66
TableScan: lineitem projection=[l_orderkey, l_shipdate, l_commitdate, l_receiptdate, l_shipmode]
77
TableScan: orders projection=[o_orderkey, o_orderpriority]

benchmarks/expected-plans/q19.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ Projection: SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS re
33
Projection: lineitem.l_extendedprice, lineitem.l_discount
44
Filter: part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15)
55
Inner Join: lineitem.l_partkey = part.p_partkey
6-
Filter: (lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)) AND lineitem.l_shipmode IN ([Utf8("AIR"), Utf8("AIR REG")]) AND lineitem.l_shipinstruct = Utf8("DELIVER IN PERSON")
6+
Filter: (lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)) AND (lineitem.l_shipmode = Utf8("AIR REG") OR lineitem.l_shipmode = Utf8("AIR")) AND lineitem.l_shipinstruct = Utf8("DELIVER IN PERSON")
77
TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice, l_discount, l_shipinstruct, l_shipmode]
88
Filter: (part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1)
99
TableScan: part projection=[p_partkey, p_brand, p_size, p_container]

datafusion/core/src/physical_plan/planner.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2040,7 +2040,9 @@ mod tests {
20402040
.build()?;
20412041
let execution_plan = plan(&logical_plan).await?;
20422042
// verify that the plan correctly adds cast from Int64(1) to Utf8, and the const will be evaluated.
2043-
let expected = "expr: [(InListExpr { expr: Column { name: \"c1\", index: 0 }, list: [Literal { value: Utf8(\"a\") }, Literal { value: Utf8(\"1\") }], negated: false }";
2043+
2044+
let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\") } }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\") } } }";
2045+
20442046
let actual = format!("{:?}", execution_plan);
20452047
assert!(actual.contains(expected), "{}", actual);
20462048

datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ pub struct ExprSimplifier<S> {
4040
info: S,
4141
}
4242

43+
const THRESHOLD_INLINE_INLIST: usize = 3;
44+
4345
impl<S: SimplifyInfo> ExprSimplifier<S> {
4446
/// Create a new `ExprSimplifier` with the given `info` such as an
4547
/// instance of [`SimplifyContext`]. See
@@ -365,7 +367,48 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> {
365367
None => lit_bool_null(),
366368
}
367369
}
370+
// expr IN () --> false
371+
// expr NOT IN () --> true
372+
Expr::InList {
373+
expr,
374+
list,
375+
negated,
376+
} if list.is_empty() && *expr != Expr::Literal(ScalarValue::Null) => {
377+
lit(negated)
378+
}
368379

380+
// if expr is a single column reference:
381+
// expr IN (A, B, C) --> (expr = A) OR (expr = B) OR (expr = C)
382+
Expr::InList {
383+
expr,
384+
list,
385+
negated,
386+
} if !list.is_empty()
387+
&& (
388+
// For lists with only 1 value we allow more complex expressions to be simplified
389+
// e.g. SUBSTR(c1, 2, 3) IN ('1') -> SUBSTR(c1, 2, 3) = '1'
390+
// for more than one we avoid repeating these potentially expensive
391+
// expressions
392+
list.len() == 1
393+
|| list.len() <= THRESHOLD_INLINE_INLIST
394+
&& expr.try_into_col().is_ok()
395+
) =>
396+
{
397+
let first_val = list[0].clone();
398+
if negated {
399+
list.into_iter()
400+
.skip(1)
401+
.fold((*expr.clone()).not_eq(first_val), |acc, y| {
402+
(*expr.clone()).not_eq(y).and(acc)
403+
})
404+
} else {
405+
list.into_iter()
406+
.skip(1)
407+
.fold((*expr.clone()).eq(first_val), |acc, y| {
408+
(*expr.clone()).eq(y).or(acc)
409+
})
410+
}
411+
}
369412
//
370413
// Rules for NotEq
371414
//
@@ -1749,6 +1792,37 @@ mod tests {
17491792
assert_eq!(expected_expr, result);
17501793
}
17511794

1795+
#[test]
1796+
fn simplify_inlist() {
1797+
assert_eq!(simplify(in_list(col("c1"), vec![], false)), lit(false));
1798+
assert_eq!(simplify(in_list(col("c1"), vec![], true)), lit(true));
1799+
1800+
assert_eq!(
1801+
simplify(in_list(col("c1"), vec![lit(1)], false)),
1802+
col("c1").eq(lit(1))
1803+
);
1804+
assert_eq!(
1805+
simplify(in_list(col("c1"), vec![lit(1)], true)),
1806+
col("c1").not_eq(lit(1))
1807+
);
1808+
1809+
// more complex expressions can be simplified if list contains
1810+
// one element only
1811+
assert_eq!(
1812+
simplify(in_list(col("c1") * lit(10), vec![lit(2)], false)),
1813+
(col("c1") * lit(10)).eq(lit(2))
1814+
);
1815+
1816+
assert_eq!(
1817+
simplify(in_list(col("c1"), vec![lit(1), lit(2)], false)),
1818+
col("c1").eq(lit(2)).or(col("c1").eq(lit(1)))
1819+
);
1820+
assert_eq!(
1821+
simplify(in_list(col("c1"), vec![lit(1), lit(2)], true)),
1822+
col("c1").not_eq(lit(2)).and(col("c1").not_eq(lit(1)))
1823+
);
1824+
}
1825+
17521826
#[test]
17531827
fn simplify_expr_bool_and() {
17541828
// col & true is always col

datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -706,7 +706,8 @@ mod tests {
706706
.unwrap()
707707
.build()
708708
.unwrap();
709-
let expected = "Filter: test.d NOT IN ([Int32(1), Int32(2), Int32(3)])\
709+
let expected =
710+
"Filter: test.d != Int32(3) AND test.d != Int32(2) AND test.d != Int32(1)\
710711
\n TableScan: test";
711712

712713
assert_optimized_plan_eq(&plan, expected);
@@ -721,7 +722,8 @@ mod tests {
721722
.unwrap()
722723
.build()
723724
.unwrap();
724-
let expected = "Filter: test.d IN ([Int32(1), Int32(2), Int32(3)])\
725+
let expected =
726+
"Filter: test.d = Int32(3) OR test.d = Int32(2) OR test.d = Int32(1)\
725727
\n TableScan: test";
726728

727729
assert_optimized_plan_eq(&plan, expected);

0 commit comments

Comments
 (0)