Skip to content
This repository has been archived by the owner on Jan 7, 2025. It is now read-only.

feat: support left-outer and left-mark hash join impl rules #274

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
2 changes: 2 additions & 0 deletions optd-datafusion-bridge/src/from_optd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,8 @@ impl OptdPlanContext<'_> {
let right_exec = self.conv_from_optd_plan_node(node.right(), meta).await?;
let join_type = match node.join_type() {
JoinType::Inner => datafusion::logical_expr::JoinType::Inner,
JoinType::LeftOuter => datafusion::logical_expr::JoinType::Left,
JoinType::LeftMark => datafusion::logical_expr::JoinType::LeftMark,
_ => unimplemented!(),
};
let left_exprs = node.left_keys().to_vec();
Expand Down
8 changes: 6 additions & 2 deletions optd-datafusion-repr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,11 @@ impl DatafusionOptimizer {
rule_wrappers.push(Arc::new(rules::FilterInnerJoinTransposeRule::new()));
rule_wrappers.push(Arc::new(rules::FilterSortTransposeRule::new()));
rule_wrappers.push(Arc::new(rules::FilterAggTransposeRule::new()));
rule_wrappers.push(Arc::new(rules::HashJoinRule::new()));
rule_wrappers.push(Arc::new(rules::JoinInnerSplitFilterRule::new()));
rule_wrappers.push(Arc::new(rules::JoinLeftOuterSplitFilterRule::new()));
rule_wrappers.push(Arc::new(rules::HashJoinInnerRule::new()));
rule_wrappers.push(Arc::new(rules::HashJoinLeftOuterRule::new()));
rule_wrappers.push(Arc::new(rules::HashJoinLeftMarkRule::new()));
rule_wrappers.push(Arc::new(rules::JoinCommuteRule::new()));
rule_wrappers.push(Arc::new(rules::JoinAssocRule::new()));
rule_wrappers.push(Arc::new(rules::ProjectionPullUpJoin::new()));
Expand Down Expand Up @@ -178,7 +182,7 @@ impl DatafusionOptimizer {
for rule in rules {
rule_wrappers.push(rule);
}
rule_wrappers.push(Arc::new(rules::HashJoinRule::new()));
rule_wrappers.push(Arc::new(rules::HashJoinInnerRule::new()));
rule_wrappers.insert(0, Arc::new(rules::JoinCommuteRule::new()));
rule_wrappers.insert(1, Arc::new(rules::JoinAssocRule::new()));
rule_wrappers.insert(2, Arc::new(rules::ProjectionPullUpJoin::new()));
Expand Down
42 changes: 18 additions & 24 deletions optd-datafusion-repr/src/rules/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,41 +38,35 @@ pub(crate) fn simplify_log_expr(log_expr: ArcDfPredNode, changed: &mut bool) ->
if let DfPredType::Constant(ConstantType::Bool) = new_child.typ {
let data = ConstantPred::from_pred_node(new_child).unwrap().value();
*changed = true;
// TrueExpr
if data.as_bool() {
if op == LogOpType::And {
// skip True in And
continue;
}
if op == LogOpType::Or {

match (data.as_bool(), op) {
(true, LogOpType::Or) => {
// replace whole exprList with True
return ConstantPred::bool(true).into_pred_node();
}
unreachable!("no other type in logOp");
}
// FalseExpr
if op == LogOpType::And {
// replace whole exprList with False
return ConstantPred::bool(false).into_pred_node();
}
if op == LogOpType::Or {
// skip False in Or
continue;
(false, LogOpType::And) => {
// replace whole exprList with False
return ConstantPred::bool(false).into_pred_node();
}
_ => {
// skip True in `And`, and False in `Or`
continue;
}
}
unreachable!("no other type in logOp");
} else if !new_children_set.contains(&new_child) {
new_children_set.insert(new_child.clone());
new_children.push(new_child);
}
}
if new_children.is_empty() {
if op == LogOpType::And {
return ConstantPred::bool(true).into_pred_node();
}
if op == LogOpType::Or {
return ConstantPred::bool(false).into_pred_node();
match op {
LogOpType::And => {
return ConstantPred::bool(true).into_pred_node();
}
LogOpType::Or => {
return ConstantPred::bool(false).into_pred_node();
}
}
unreachable!("no other type in logOp");
}
if new_children.len() == 1 {
*changed = true;
Expand Down
174 changes: 174 additions & 0 deletions optd-datafusion-repr/src/rules/filter_pushdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,87 @@ fn apply_filter_merge(
vec![new_filter.into_plan_node().into()]
}

// Rule declared via `define_rule!` that matches a logical `Join` of type `JoinType::Inner`
// over two children (bound as `child_a` and `child_b`) and hands the match to
// `apply_join_split_filter`.
// That function splits the join condition into predicates that depend on a single input,
// which become filters on the join's children, and predicates that stay in the join condition.
define_rule!(
JoinInnerSplitFilterRule,
apply_join_split_filter,
(Join(JoinType::Inner), child_a, child_b)
);

define_rule!(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this rule is correct. You cannot move the outer join condition into a filter in some cases.

Consider `select * from a left join b on a.x = b.y and b.z = 1`. Its result differs from `select * from a left join b on a.x = b.y where b.z = 1`. Assume the left table holds a single row with x=1 and the right table a single row with y=1, z=2: the correct result is (1, NULL, NULL), whereas the rewritten query produces zero rows.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ahh, I realized that this is a filter pushdown, then it might be correct; I will do a review later :)

JoinLeftOuterSplitFilterRule,
apply_join_split_filter,
(Join(JoinType::LeftOuter), child_a, child_b)
);

fn apply_join_split_filter(
optimizer: &impl Optimizer<DfNodeType>,
binding: ArcDfPlanNode,
) -> Vec<PlanNodeOrGroup<DfNodeType>> {
let join = LogicalJoin::from_plan_node(binding.clone()).unwrap();
let left_child = join.left();
let right_child = join.right();
let join_cond = join.cond();
let join_typ = join.join_type();

let left_schema_size = optimizer.get_schema_of(left_child.clone()).len();
let right_schema_size = optimizer.get_schema_of(right_child.clone()).len();

// Conditions that only involve the left relation.
let mut left_conds = vec![];
// Conditions that only involve the right relation.
let mut right_conds = vec![];
// Conditions that involve both relations.
let mut keep_conds = vec![];

let categorization_fn = |expr: ArcDfPredNode, children: &[ArcDfPredNode]| {
let location = determine_join_cond_dep(children, left_schema_size, right_schema_size);
match location {
JoinCondDependency::Left => left_conds.push(expr),
JoinCondDependency::Right => right_conds.push(
expr.rewrite_column_refs(|idx| {
Some(LogicalJoin::map_through_join(
idx,
left_schema_size,
right_schema_size,
))
})
.unwrap(),
),
JoinCondDependency::Both | JoinCondDependency::None => {
// JoinCondDependency::None could happy if there are no column refs in the predicate.
// e.g. true for CrossJoin.
keep_conds.push(expr);
}
}
};
categorize_conds(categorization_fn, join_cond);

let new_left = if !left_conds.is_empty() {
let new_filter_node =
LogicalFilter::new_unchecked(left_child, and_expr_list_to_expr(left_conds));
PlanNodeOrGroup::PlanNode(new_filter_node.into_plan_node())
} else {
left_child
};

let new_right = if !right_conds.is_empty() {
let new_filter_node =
LogicalFilter::new_unchecked(right_child, and_expr_list_to_expr(right_conds));
PlanNodeOrGroup::PlanNode(new_filter_node.into_plan_node())
} else {
right_child
};

let new_join = LogicalJoin::new_unchecked(
new_left,
new_right,
and_expr_list_to_expr(keep_conds),
*join_typ,
);

vec![new_join.into_plan_node().into()]
}
define_rule!(
FilterInnerJoinTransposeRule,
apply_filter_inner_join_transpose,
Expand Down Expand Up @@ -369,6 +450,8 @@ fn apply_filter_agg_transpose(
mod tests {
use std::sync::Arc;

use optd_core::nodes::Value;

use super::*;
use crate::plan_nodes::{BinOpPred, BinOpType, ConstantPred, LogicalScan};
use crate::testing::new_test_optimizer;
Expand Down Expand Up @@ -442,6 +525,97 @@ mod tests {
assert_eq!(col_4.value().as_i32(), 1);
}

#[test]
fn join_split_filter() {
    // Checks how the condition of a LEFT OUTER join is split.
    //
    // Only the push-down that is valid for a left outer join is exercised: a predicate on
    // the right (null-supplying) input moves into a filter on the right child, while a
    // predicate over both inputs and a constant stay in the join condition. There is
    // deliberately no left-only predicate: filtering the preserved left input below the
    // join would drop rows that a left outer join must emit NULL-padded, so such a
    // push-down must not be pinned by this test.
    let mut test_optimizer = new_test_optimizer(Arc::new(JoinLeftOuterSplitFilterRule::new()));

    let scan1 = LogicalScan::new("customer".into());

    let scan2 = LogicalScan::new("orders".into());

    let join_cond = LogOpPred::new(
        LogOpType::And,
        vec![
            BinOpPred::new(
                // This one should be pushed to the right child
                ColumnRefPred::new(11).into_pred_node(),
                ConstantPred::int32(6).into_pred_node(),
                BinOpType::Eq,
            )
            .into_pred_node(),
            BinOpPred::new(
                // This one stays in the join condition.
                ColumnRefPred::new(2).into_pred_node(),
                ColumnRefPred::new(8).into_pred_node(),
                BinOpType::Eq,
            )
            .into_pred_node(),
            // This one stays in the join condition.
            ConstantPred::bool(true).into_pred_node(),
        ],
    );

    let join = LogicalJoin::new(
        scan1.into_plan_node(),
        scan2.into_plan_node(),
        join_cond.into_pred_node(),
        super::JoinType::LeftOuter,
    );

    let plan = test_optimizer.optimize(join.into_plan_node()).unwrap();
    let join = LogicalJoin::from_plan_node(plan.clone()).unwrap();

    // The rule must not change the join type.
    assert_eq!(join.join_type(), &JoinType::LeftOuter);

    {
        // Examine join conditions: the two-sided equality and the constant remain.
        let join_conds = LogOpPred::from_pred_node(join.cond()).unwrap();
        assert!(matches!(join_conds.op_type(), LogOpType::And));
        assert_eq!(join_conds.children().len(), 2);
        let bin_op_with_both_ref =
            BinOpPred::from_pred_node(join_conds.children()[0].clone()).unwrap();
        assert!(matches!(bin_op_with_both_ref.op_type(), BinOpType::Eq));
        let col_2 = ColumnRefPred::from_pred_node(bin_op_with_both_ref.left_child()).unwrap();
        let col_8 = ColumnRefPred::from_pred_node(bin_op_with_both_ref.right_child()).unwrap();
        assert_eq!(col_2.index(), 2);
        assert_eq!(col_8.index(), 8);
        let constant_true =
            ConstantPred::from_pred_node(join_conds.children()[1].clone()).unwrap();
        assert_eq!(constant_true.value(), Value::Bool(true));
    }

    {
        // Examine right child filter + condition. Join column 11 is expected to show up
        // as column 3 of the right child.
        let filter_right =
            LogicalFilter::from_plan_node(join.right().unwrap_plan_node()).unwrap();
        let bin_op = BinOpPred::from_pred_node(filter_right.cond()).unwrap();
        assert!(matches!(bin_op.op_type(), BinOpType::Eq));
        let col = ColumnRefPred::from_pred_node(bin_op.left_child()).unwrap();
        let constant = ConstantPred::from_pred_node(bin_op.right_child()).unwrap();
        assert_eq!(col.index(), 3);
        assert_eq!(constant.value().as_i32(), 6);
    }
}

#[test]
fn push_past_join_conjunction() {
// Test pushing a complex filter past a join, where one clause can
Expand Down
19 changes: 16 additions & 3 deletions optd-datafusion-repr/src/rules/joins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,23 @@ fn apply_join_assoc(
}

define_impl_rule!(
HashJoinRule,
HashJoinInnerRule,
apply_hash_join,
(Join(JoinType::Inner), left, right)
);

// Rule declared via `define_impl_rule!` that matches a logical `Join` of type
// `JoinType::LeftOuter` over two children (bound as `left` and `right`) and hands the match
// to `apply_hash_join`, the same function used by the other hash-join rules in this file.
define_impl_rule!(
HashJoinLeftOuterRule,
apply_hash_join,
(Join(JoinType::LeftOuter), left, right)
);

// Rule declared via `define_impl_rule!` that matches a logical `Join` of type
// `JoinType::LeftMark` over two children (bound as `left` and `right`) and hands the match
// to `apply_hash_join`, the same function used by the other hash-join rules in this file.
define_impl_rule!(
HashJoinLeftMarkRule,
apply_hash_join,
(Join(JoinType::LeftMark), left, right)
);

fn apply_hash_join(
optimizer: &impl Optimizer<DfNodeType>,
binding: ArcDfPlanNode,
Expand All @@ -154,6 +166,7 @@ fn apply_hash_join(
let cond = join.cond();
let left = join.left();
let right = join.right();
let join_type = join.join_type();
match cond.typ {
DfPredType::BinOp(BinOpType::Eq) => {
let left_schema = optimizer.get_schema_of(left.clone());
Expand Down Expand Up @@ -186,7 +199,7 @@ fn apply_hash_join(
right,
ListPred::new(vec![left_expr.into_pred_node()]),
ListPred::new(vec![right_expr.into_pred_node()]),
JoinType::Inner,
*join_type,
);
return vec![node.into_plan_node().into()];
}
Expand Down Expand Up @@ -244,7 +257,7 @@ fn apply_hash_join(
right,
ListPred::new(left_exprs),
ListPred::new(right_exprs),
JoinType::Inner,
*join_type,
);
return vec![node.into_plan_node().into()];
}
Expand Down
Loading
Loading