@@ -24,7 +24,7 @@ use optd_core::optimizer::Optimizer;
2424use optd_core:: rules:: { Rule , RuleMatcher } ;
2525
2626use super :: filter:: simplify_log_expr;
27- use super :: macros:: define_rule;
27+ use super :: macros:: { define_rule, define_rule_discriminant } ;
2828use crate :: plan_nodes:: {
2929 ArcDfPlanNode , ArcDfPredNode , ColumnRefPred , DfNodeType , DfPredType , DfReprPlanNode ,
3030 DfReprPredNode , JoinType , ListPred , LogOpPred , LogOpType , LogicalAgg , LogicalFilter ,
@@ -160,6 +160,87 @@ fn apply_filter_merge(
160160 vec ! [ new_filter. into_plan_node( ) . into( ) ]
161161}
162162
163+ // Rule to split predicates in a join condition into those that can be pushed down as filters.
164+ define_rule ! (
165+ InnerJoinSplitFilterRule ,
166+ apply_join_split_filter,
167+ ( Join ( JoinType :: Inner ) , child_a, child_b)
168+ ) ;
169+
170+ define_rule ! (
171+ LeftOuterJoinSplitFilterRule ,
172+ apply_join_split_filter,
173+ ( Join ( JoinType :: LeftOuter ) , child_a, child_b)
174+ ) ;
175+
176+ fn apply_join_split_filter (
177+ optimizer : & impl Optimizer < DfNodeType > ,
178+ binding : ArcDfPlanNode ,
179+ ) -> Vec < PlanNodeOrGroup < DfNodeType > > {
180+ println ! ( "Applying JoinSplitFilterRule" ) ;
181+ let join = LogicalJoin :: from_plan_node ( binding) . unwrap ( ) ;
182+ let left_child = join. left ( ) ;
183+ let right_child = join. right ( ) ;
184+ let join_cond = join. cond ( ) ;
185+ let join_typ = join. join_type ( ) ;
186+
187+ let left_schema_size = optimizer. get_schema_of ( left_child. clone ( ) ) . len ( ) ;
188+ let right_schema_size = optimizer. get_schema_of ( right_child. clone ( ) ) . len ( ) ;
189+
190+ // Conditions that only involve the left relation.
191+ let mut left_conds = vec ! [ ] ;
192+ // Conditions that only involve the right relation.
193+ let mut right_conds = vec ! [ ] ;
194+ // Conditions that involve both relations.
195+ let mut keep_conds = vec ! [ ] ;
196+
197+ let categorization_fn = |expr : ArcDfPredNode , children : & [ ArcDfPredNode ] | {
198+ let location = determine_join_cond_dep ( children, left_schema_size, right_schema_size) ;
199+ match location {
200+ JoinCondDependency :: Left => left_conds. push ( expr) ,
201+ JoinCondDependency :: Right => right_conds. push (
202+ expr. rewrite_column_refs ( |idx| {
203+ Some ( LogicalJoin :: map_through_join (
204+ idx,
205+ left_schema_size,
206+ right_schema_size,
207+ ) )
208+ } )
209+ . unwrap ( ) ,
210+ ) ,
211+ JoinCondDependency :: Both => keep_conds. push ( expr) ,
212+ JoinCondDependency :: None => {
213+ unreachable ! ( "join condition should always involve at least one relation" ) ;
214+ }
215+ }
216+ } ;
217+ categorize_conds ( categorization_fn, join_cond) ;
218+
219+ let new_left = if !left_conds. is_empty ( ) {
220+ let new_filter_node =
221+ LogicalFilter :: new_unchecked ( left_child, and_expr_list_to_expr ( left_conds) ) ;
222+ PlanNodeOrGroup :: PlanNode ( new_filter_node. into_plan_node ( ) )
223+ } else {
224+ left_child
225+ } ;
226+
227+ let new_right = if !right_conds. is_empty ( ) {
228+ let new_filter_node =
229+ LogicalFilter :: new_unchecked ( right_child, and_expr_list_to_expr ( right_conds) ) ;
230+ PlanNodeOrGroup :: PlanNode ( new_filter_node. into_plan_node ( ) )
231+ } else {
232+ right_child
233+ } ;
234+
235+ let new_join = LogicalJoin :: new_unchecked (
236+ new_left,
237+ new_right,
238+ and_expr_list_to_expr ( keep_conds) ,
239+ * join_typ,
240+ ) ;
241+
242+ vec ! [ new_join. into_plan_node( ) . into( ) ]
243+ }
163244define_rule ! (
164245 FilterInnerJoinTransposeRule ,
165246 apply_filter_inner_join_transpose,
@@ -442,6 +523,52 @@ mod tests {
442523 assert_eq ! ( col_4. value( ) . as_i32( ) , 1 ) ;
443524 }
444525
#[test]
fn join_split_filter() {
    // Runs the left-outer-join split rule over a `customer LEFT OUTER JOIN
    // orders` plan whose condition is a conjunction of three equalities, then
    // prints the optimized plan. The test makes no assertions on the result.
    let mut test_optimizer = new_test_optimizer(Arc::new(LeftOuterJoinSplitFilterRule::new()));

    let customer_scan = LogicalScan::new("customer".into());
    let orders_scan = LogicalScan::new("orders".into());

    // Conjunct comparing column 0 with a constant; intended to reference only
    // the left input.
    let left_only_pred = BinOpPred::new(
        ColumnRefPred::new(0).into_pred_node(),
        ConstantPred::int32(5).into_pred_node(),
        BinOpType::Eq,
    )
    .into_pred_node();

    // Conjunct comparing column 11 with a constant; intended to reference only
    // the right input.
    let right_only_pred = BinOpPred::new(
        ColumnRefPred::new(11).into_pred_node(),
        ConstantPred::int32(6).into_pred_node(),
        BinOpType::Eq,
    )
    .into_pred_node();

    // Conjunct comparing columns 2 and 8; intended to reference both inputs,
    // so it has to stay in the join condition.
    let both_sides_pred = BinOpPred::new(
        ColumnRefPred::new(2).into_pred_node(),
        ColumnRefPred::new(8).into_pred_node(),
        BinOpType::Eq,
    )
    .into_pred_node();

    let join_cond = LogOpPred::new(
        LogOpType::And,
        vec![left_only_pred, right_only_pred, both_sides_pred],
    );

    let join = LogicalJoin::new(
        customer_scan.into_plan_node(),
        orders_scan.into_plan_node(),
        join_cond.into_pred_node(),
        super::JoinType::LeftOuter,
    );

    let plan = test_optimizer.optimize(join.into_plan_node()).unwrap();
    println!("{}", plan.explain_to_string(None));
}
571+
445572 #[ test]
446573 fn push_past_join_conjunction ( ) {
447574 // Test pushing a complex filter past a join, where one clause can
0 commit comments