Skip to content

Commit 6b71294

Browse files
authored
Linearize binary expressions to reduce proto tree complexity (#4115)
1 parent 3892a1f commit 6b71294

File tree

6 files changed

+168
-46
lines changed

6 files changed

+168
-46
lines changed

datafusion/proto/proto/datafusion.proto

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -409,8 +409,10 @@ message AliasNode {
409409
}
410410

411411
message BinaryExprNode {
412-
LogicalExprNode l = 1;
413-
LogicalExprNode r = 2;
412+
// Represents the operands from the left inner most expression
413+
// to the right outer most expression where each of them is chained
414+
// with the operator 'op'.
415+
repeated LogicalExprNode operands = 1;
414416
string op = 3;
415417
}
416418

datafusion/proto/src/bytes/mod.rs

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,98 @@ mod test {
321321
Expr::from_bytes(&bytes).unwrap();
322322
}
323323

324+
fn roundtrip_expr(expr: &Expr) -> Expr {
325+
let bytes = expr.to_bytes().unwrap();
326+
Expr::from_bytes(&bytes).unwrap()
327+
}
328+
329+
#[test]
330+
fn exact_roundtrip_linearized_binary_expr() {
331+
// (((A AND B) AND C) AND D)
332+
let expr_ordered = col("A").and(col("B")).and(col("C")).and(col("D"));
333+
assert_eq!(expr_ordered, roundtrip_expr(&expr_ordered));
334+
335+
// Ensure that no other variation becomes equal
336+
let other_variants = vec![
337+
// (((B AND A) AND C) AND D)
338+
col("B").and(col("A")).and(col("C")).and(col("D")),
339+
// (((A AND C) AND B) AND D)
340+
col("A").and(col("C")).and(col("B")).and(col("D")),
341+
// (((A AND B) AND D) AND C)
342+
col("A").and(col("B")).and(col("D")).and(col("C")),
343+
// (A AND (B AND (C AND D)))
344+
col("A").and(col("B").and(col("C").and(col("D")))),
345+
];
346+
for case in other_variants {
347+
// Each variant is still equal to itself
348+
assert_eq!(case, roundtrip_expr(&case));
349+
350+
// But none of them is equal to the original
351+
assert_ne!(expr_ordered, roundtrip_expr(&case));
352+
assert_ne!(roundtrip_expr(&expr_ordered), roundtrip_expr(&case));
353+
}
354+
}
355+
356+
#[test]
357+
fn roundtrip_deeply_nested_binary_expr() {
358+
// We need more stack space so this doesn't overflow in dev builds
359+
std::thread::Builder::new()
360+
.stack_size(10_000_000)
361+
.spawn(|| {
362+
let n = 100;
363+
// a < 5
364+
let basic_expr = col("a").lt(lit(5i32));
365+
// (a < 5) OR (a < 5) OR (a < 5) OR ...
366+
let or_chain = (0..n)
367+
.fold(basic_expr.clone(), |expr, _| expr.or(basic_expr.clone()));
368+
// ((a < 5) OR (a < 5) OR ...) AND ((a < 5) OR (a < 5) OR ...) AND ...
369+
let expr =
370+
(0..n).fold(or_chain.clone(), |expr, _| expr.and(or_chain.clone()));
371+
372+
// Should work fine.
373+
let bytes = expr.to_bytes().unwrap();
374+
375+
let decoded_expr = Expr::from_bytes(&bytes).expect(
376+
"serialization worked, so deserialization should work as well",
377+
);
378+
assert_eq!(decoded_expr, expr);
379+
})
380+
.expect("spawning thread")
381+
.join()
382+
.expect("joining thread");
383+
}
384+
385+
#[test]
386+
fn roundtrip_deeply_nested_binary_expr_reverse_order() {
387+
// We need more stack space so this doesn't overflow in dev builds
388+
std::thread::Builder::new()
389+
.stack_size(10_000_000)
390+
.spawn(|| {
391+
let n = 100;
392+
393+
// a < 5
394+
let expr_base = col("a").lt(lit(5i32));
395+
396+
// ((a < 5 AND a < 5) AND a < 5) AND ...
397+
let and_chain =
398+
(0..n).fold(expr_base.clone(), |expr, _| expr.and(expr_base.clone()));
399+
400+
// a < 5 AND (((a < 5 AND a < 5) AND a < 5) AND ...)
401+
let expr = expr_base.and(and_chain);
402+
403+
// Should work fine.
404+
let bytes = expr.to_bytes().unwrap();
405+
406+
let decoded_expr = Expr::from_bytes(&bytes).expect(
407+
"serialization worked, so deserialization should work as well",
408+
);
409+
assert_eq!(decoded_expr, expr);
410+
})
411+
.expect("spawning thread")
412+
.join()
413+
.expect("joining thread");
414+
}
415+
324416
#[test]
325417
fn roundtrip_deeply_nested() {
326418
// we need more stack space so this doesn't overflow in dev builds
@@ -332,7 +424,8 @@ mod test {
332424
println!("testing: {n}");
333425

334426
let expr_base = col("a").lt(lit(5i32));
335-
let expr = (0..n).fold(expr_base.clone(), |expr, _| expr.and(expr_base.clone()));
427+
// Generate a tree of alternating AND and OR expressions (no two consecutive operators are the same).
428+
let expr = (0..n).fold(expr_base.clone(), |expr, n| if n % 2 == 0 { expr.and(expr_base.clone()) } else { expr.or(expr_base.clone()) });
336429

337430
// Convert it to an opaque form
338431
let bytes = match expr.to_bytes() {

datafusion/proto/src/from_proto.rs

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -690,11 +690,29 @@ pub fn parse_expr(
690690
.ok_or_else(|| Error::required("expr_type"))?;
691691

692692
match expr_type {
693-
ExprType::BinaryExpr(binary_expr) => Ok(Expr::BinaryExpr(BinaryExpr::new(
694-
Box::new(parse_required_expr(&binary_expr.l, registry, "l")?),
695-
from_proto_binary_op(&binary_expr.op)?,
696-
Box::new(parse_required_expr(&binary_expr.r, registry, "r")?),
697-
))),
693+
ExprType::BinaryExpr(binary_expr) => {
694+
let op = from_proto_binary_op(&binary_expr.op)?;
695+
let operands = binary_expr
696+
.operands
697+
.iter()
698+
.map(|expr| parse_expr(expr, registry))
699+
.collect::<Result<Vec<_>, _>>()?;
700+
701+
if operands.len() < 2 {
702+
return Err(proto_error(
703+
"A binary expression must always have at least 2 operands",
704+
));
705+
}
706+
707+
// Reduce the linearized operands (ordered by left innermost to right
708+
// outermost) into a single expression tree.
709+
Ok(operands
710+
.into_iter()
711+
.reduce(|left, right| {
712+
Expr::BinaryExpr(BinaryExpr::new(Box::new(left), op, Box::new(right)))
713+
})
714+
.expect("Binary expression could not be reduced to a single expression."))
715+
}
698716
ExprType::GetIndexedField(field) => {
699717
let key = field
700718
.key

datafusion/proto/src/generated/pbjson.rs

Lines changed: 12 additions & 29 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/proto/src/generated/prost.rs

Lines changed: 6 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/proto/src/to_proto.rs

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -455,11 +455,36 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
455455
}
456456
}
457457
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
458-
let binary_expr = Box::new(protobuf::BinaryExprNode {
459-
l: Some(Box::new(left.as_ref().try_into()?)),
460-
r: Some(Box::new(right.as_ref().try_into()?)),
458+
// Try to linearize a nested binary expression tree of the same operator
459+
// into a flat vector of expressions.
460+
let mut exprs = vec![right.as_ref()];
461+
let mut current_expr = left.as_ref();
462+
while let Expr::BinaryExpr(BinaryExpr {
463+
left,
464+
op: current_op,
465+
right,
466+
}) = current_expr
467+
{
468+
if current_op == op {
469+
exprs.push(right.as_ref());
470+
current_expr = left.as_ref();
471+
} else {
472+
break;
473+
}
474+
}
475+
exprs.push(current_expr);
476+
477+
let binary_expr = protobuf::BinaryExprNode {
478+
// We need to reverse exprs since operands are expected to be
479+
// linearized from left innermost to right outermost (but while
480+
// traversing the chain we do the exact opposite).
481+
operands: exprs
482+
.into_iter()
483+
.rev()
484+
.map(|expr| expr.try_into())
485+
.collect::<Result<Vec<_>, Error>>()?,
461486
op: format!("{:?}", op),
462-
});
487+
};
463488
Self {
464489
expr_type: Some(ExprType::BinaryExpr(binary_expr)),
465490
}

0 commit comments

Comments
 (0)