Skip to content
This repository was archived by the owner on Jan 7, 2025. It is now read-only.

Commit 8138ae3

Browse files
committed
Revert "use define_impl_rule_discriminant"
This reverts commit 973e80c.
1 parent 973e80c commit 8138ae3

File tree

3 files changed

+240
-18
lines changed

3 files changed

+240
-18
lines changed

optd-datafusion-repr/src/lib.rs

+4-2
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,9 @@ impl DatafusionOptimizer {
105105
rule_wrappers.push(Arc::new(rules::FilterInnerJoinTransposeRule::new()));
106106
rule_wrappers.push(Arc::new(rules::FilterSortTransposeRule::new()));
107107
rule_wrappers.push(Arc::new(rules::FilterAggTransposeRule::new()));
108-
rule_wrappers.push(Arc::new(rules::HashJoinRule::new()));
108+
rule_wrappers.push(Arc::new(rules::HashJoinInnerRule::new()));
109+
rule_wrappers.push(Arc::new(rules::HashJoinLeftOuterRule::new()));
110+
rule_wrappers.push(Arc::new(rules::HashJoinLeftMarkRule::new()));
109111
rule_wrappers.push(Arc::new(rules::JoinCommuteRule::new()));
110112
rule_wrappers.push(Arc::new(rules::JoinAssocRule::new()));
111113
rule_wrappers.push(Arc::new(rules::ProjectionPullUpJoin::new()));
@@ -177,7 +179,7 @@ impl DatafusionOptimizer {
177179
for rule in rules {
178180
rule_wrappers.push(rule);
179181
}
180-
rule_wrappers.push(Arc::new(rules::HashJoinRule::new()));
182+
rule_wrappers.push(Arc::new(rules::HashJoinInnerRule::new()));
181183
rule_wrappers.insert(0, Arc::new(rules::JoinCommuteRule::new()));
182184
rule_wrappers.insert(1, Arc::new(rules::JoinAssocRule::new()));
183185
rule_wrappers.insert(2, Arc::new(rules::ProjectionPullUpJoin::new()));

optd-datafusion-repr/src/rules/joins.rs

+233-6
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use optd_core::nodes::PlanNodeOrGroup;
99
use optd_core::optimizer::Optimizer;
1010
use optd_core::rules::{Rule, RuleMatcher};
1111

12-
use super::macros::{define_impl_rule_discriminant, define_rule};
12+
use super::macros::{define_impl_rule, define_rule};
1313
use crate::plan_nodes::{
1414
ArcDfPlanNode, BinOpPred, BinOpType, ColumnRefPred, ConstantPred, ConstantType, DfNodeType,
1515
DfPredType, DfReprPlanNode, DfReprPredNode, JoinType, ListPred, LogOpType,
@@ -140,14 +140,241 @@ fn apply_join_assoc(
140140
vec![node.into_plan_node().into()]
141141
}
142142

143-
// Note: this matches all join types despite using `JoinType::Inner` below.
144-
define_impl_rule_discriminant!(
145-
HashJoinRule,
146-
apply_hash_join,
143+
// Invokes `define_impl_rule!` with rule name `HashJoinInnerRule`, apply
// function `apply_hash_join_inner`, and the matcher pattern
// `(Join(JoinType::Inner), left, right)`.
// NOTE(review): presumably this makes the rule fire only for inner joins —
// confirm against `define_rule_inner!`.
define_impl_rule!(
144+
HashJoinInnerRule,
145+
apply_hash_join_inner,
147146
(Join(JoinType::Inner), left, right)
148147
);
149148

150-
fn apply_hash_join(
149+
fn apply_hash_join_inner(
150+
optimizer: &impl Optimizer<DfNodeType>,
151+
binding: ArcDfPlanNode,
152+
) -> Vec<PlanNodeOrGroup<DfNodeType>> {
153+
let join = LogicalJoin::from_plan_node(binding).unwrap();
154+
let cond = join.cond();
155+
let left = join.left();
156+
let right = join.right();
157+
let join_type = join.join_type();
158+
match cond.typ {
159+
DfPredType::BinOp(BinOpType::Eq) => {
160+
let left_schema = optimizer.get_schema_of(left.clone());
161+
let op = BinOpPred::from_pred_node(cond.clone()).unwrap();
162+
let left_expr = op.left_child();
163+
let right_expr = op.right_child();
164+
let Some(mut left_expr) = ColumnRefPred::from_pred_node(left_expr) else {
165+
return vec![];
166+
};
167+
let Some(mut right_expr) = ColumnRefPred::from_pred_node(right_expr) else {
168+
return vec![];
169+
};
170+
let can_convert = if left_expr.index() < left_schema.len()
171+
&& right_expr.index() >= left_schema.len()
172+
{
173+
true
174+
} else if right_expr.index() < left_schema.len()
175+
&& left_expr.index() >= left_schema.len()
176+
{
177+
(left_expr, right_expr) = (right_expr, left_expr);
178+
true
179+
} else {
180+
false
181+
};
182+
183+
if can_convert {
184+
let right_expr = ColumnRefPred::new(right_expr.index() - left_schema.len());
185+
let node = PhysicalHashJoin::new_unchecked(
186+
left,
187+
right,
188+
ListPred::new(vec![left_expr.into_pred_node()]),
189+
ListPred::new(vec![right_expr.into_pred_node()]),
190+
*join_type,
191+
);
192+
return vec![node.into_plan_node().into()];
193+
}
194+
}
195+
DfPredType::LogOp(LogOpType::And) => {
196+
// currently only support consecutive equal queries
197+
let mut is_consecutive_eq = true;
198+
for child in cond.children.clone() {
199+
if let DfPredType::BinOp(BinOpType::Eq) = child.typ {
200+
continue;
201+
} else {
202+
is_consecutive_eq = false;
203+
break;
204+
}
205+
}
206+
if !is_consecutive_eq {
207+
return vec![];
208+
}
209+
210+
let left_schema = optimizer.get_schema_of(left.clone());
211+
let mut left_exprs = vec![];
212+
let mut right_exprs = vec![];
213+
for child in &cond.children {
214+
let bin_op = BinOpPred::from_pred_node(child.clone()).unwrap();
215+
let left_expr = bin_op.left_child();
216+
let right_expr = bin_op.right_child();
217+
let Some(mut left_expr) = ColumnRefPred::from_pred_node(left_expr) else {
218+
return vec![];
219+
};
220+
let Some(mut right_expr) = ColumnRefPred::from_pred_node(right_expr) else {
221+
return vec![];
222+
};
223+
let can_convert = if left_expr.index() < left_schema.len()
224+
&& right_expr.index() >= left_schema.len()
225+
{
226+
true
227+
} else if right_expr.index() < left_schema.len()
228+
&& left_expr.index() >= left_schema.len()
229+
{
230+
(left_expr, right_expr) = (right_expr, left_expr);
231+
true
232+
} else {
233+
false
234+
};
235+
if !can_convert {
236+
return vec![];
237+
}
238+
let right_expr = ColumnRefPred::new(right_expr.index() - left_schema.len());
239+
right_exprs.push(right_expr.into_pred_node());
240+
left_exprs.push(left_expr.into_pred_node());
241+
}
242+
243+
let node = PhysicalHashJoin::new_unchecked(
244+
left,
245+
right,
246+
ListPred::new(left_exprs),
247+
ListPred::new(right_exprs),
248+
*join_type,
249+
);
250+
return vec![node.into_plan_node().into()];
251+
}
252+
_ => {}
253+
}
254+
vec![]
255+
}
256+
257+
// Invokes `define_impl_rule!` with rule name `HashJoinLeftOuterRule`, apply
// function `apply_hash_join_left_outer`, and the matcher pattern
// `(Join(JoinType::LeftOuter), left, right)`.
// NOTE(review): presumably this makes the rule fire only for left outer
// joins — confirm against `define_rule_inner!`.
define_impl_rule!(
258+
HashJoinLeftOuterRule,
259+
apply_hash_join_left_outer,
260+
(Join(JoinType::LeftOuter), left, right)
261+
);
262+
263+
fn apply_hash_join_left_outer(
264+
optimizer: &impl Optimizer<DfNodeType>,
265+
binding: ArcDfPlanNode,
266+
) -> Vec<PlanNodeOrGroup<DfNodeType>> {
267+
let join = LogicalJoin::from_plan_node(binding).unwrap();
268+
let cond = join.cond();
269+
let left = join.left();
270+
let right = join.right();
271+
let join_type = join.join_type();
272+
match cond.typ {
273+
DfPredType::BinOp(BinOpType::Eq) => {
274+
let left_schema = optimizer.get_schema_of(left.clone());
275+
let op = BinOpPred::from_pred_node(cond.clone()).unwrap();
276+
let left_expr = op.left_child();
277+
let right_expr = op.right_child();
278+
let Some(mut left_expr) = ColumnRefPred::from_pred_node(left_expr) else {
279+
return vec![];
280+
};
281+
let Some(mut right_expr) = ColumnRefPred::from_pred_node(right_expr) else {
282+
return vec![];
283+
};
284+
let can_convert = if left_expr.index() < left_schema.len()
285+
&& right_expr.index() >= left_schema.len()
286+
{
287+
true
288+
} else if right_expr.index() < left_schema.len()
289+
&& left_expr.index() >= left_schema.len()
290+
{
291+
(left_expr, right_expr) = (right_expr, left_expr);
292+
true
293+
} else {
294+
false
295+
};
296+
297+
if can_convert {
298+
let right_expr = ColumnRefPred::new(right_expr.index() - left_schema.len());
299+
let node = PhysicalHashJoin::new_unchecked(
300+
left,
301+
right,
302+
ListPred::new(vec![left_expr.into_pred_node()]),
303+
ListPred::new(vec![right_expr.into_pred_node()]),
304+
*join_type,
305+
);
306+
return vec![node.into_plan_node().into()];
307+
}
308+
}
309+
DfPredType::LogOp(LogOpType::And) => {
310+
// currently only support consecutive equal queries
311+
let mut is_consecutive_eq = true;
312+
for child in cond.children.clone() {
313+
if let DfPredType::BinOp(BinOpType::Eq) = child.typ {
314+
continue;
315+
} else {
316+
is_consecutive_eq = false;
317+
break;
318+
}
319+
}
320+
if !is_consecutive_eq {
321+
return vec![];
322+
}
323+
324+
let left_schema = optimizer.get_schema_of(left.clone());
325+
let mut left_exprs = vec![];
326+
let mut right_exprs = vec![];
327+
for child in &cond.children {
328+
let bin_op = BinOpPred::from_pred_node(child.clone()).unwrap();
329+
let left_expr = bin_op.left_child();
330+
let right_expr = bin_op.right_child();
331+
let Some(mut left_expr) = ColumnRefPred::from_pred_node(left_expr) else {
332+
return vec![];
333+
};
334+
let Some(mut right_expr) = ColumnRefPred::from_pred_node(right_expr) else {
335+
return vec![];
336+
};
337+
let can_convert = if left_expr.index() < left_schema.len()
338+
&& right_expr.index() >= left_schema.len()
339+
{
340+
true
341+
} else if right_expr.index() < left_schema.len()
342+
&& left_expr.index() >= left_schema.len()
343+
{
344+
(left_expr, right_expr) = (right_expr, left_expr);
345+
true
346+
} else {
347+
false
348+
};
349+
if !can_convert {
350+
return vec![];
351+
}
352+
let right_expr = ColumnRefPred::new(right_expr.index() - left_schema.len());
353+
right_exprs.push(right_expr.into_pred_node());
354+
left_exprs.push(left_expr.into_pred_node());
355+
}
356+
357+
let node = PhysicalHashJoin::new_unchecked(
358+
left,
359+
right,
360+
ListPred::new(left_exprs),
361+
ListPred::new(right_exprs),
362+
*join_type,
363+
);
364+
return vec![node.into_plan_node().into()];
365+
}
366+
_ => {}
367+
}
368+
vec![]
369+
}
370+
371+
// Invokes `define_impl_rule!` with rule name `HashJoinLeftMarkRule`, apply
// function `apply_hash_join_left_mark`, and the matcher pattern
// `(Join(JoinType::LeftMark), left, right)`.
// NOTE(review): presumably this makes the rule fire only for left mark
// joins — confirm against `define_rule_inner!`.
define_impl_rule!(
372+
HashJoinLeftMarkRule,
373+
apply_hash_join_left_mark,
374+
(Join(JoinType::LeftMark), left, right)
375+
);
376+
377+
fn apply_hash_join_left_mark(
151378
optimizer: &impl Optimizer<DfNodeType>,
152379
binding: ArcDfPlanNode,
153380
) -> Vec<PlanNodeOrGroup<DfNodeType>> {

optd-datafusion-repr/src/rules/macros.rs

+3-10
Original file line numberDiff line numberDiff line change
@@ -79,19 +79,12 @@ macro_rules! define_rule_discriminant {
7979
};
8080
}
8181

82-
// macro_rules! define_impl_rule {
83-
// ($name:ident, $apply:ident, $($matcher:tt)+) => {
84-
// crate::rules::macros::define_rule_inner! { true, false, $name, $apply, $($matcher)+ }
85-
// };
86-
// }
87-
88-
macro_rules! define_impl_rule_discriminant {
82+
macro_rules! define_impl_rule {
8983
($name:ident, $apply:ident, $($matcher:tt)+) => {
90-
crate::rules::macros::define_rule_inner! { true, true, $name, $apply, $($matcher)+ }
84+
crate::rules::macros::define_rule_inner! { true, false, $name, $apply, $($matcher)+ }
9185
};
9286
}
9387

9488
pub(crate) use {
95-
define_impl_rule_discriminant, define_matcher, define_rule, define_rule_discriminant,
96-
define_rule_inner,
89+
define_impl_rule, define_matcher, define_rule, define_rule_discriminant, define_rule_inner,
9790
};

0 commit comments

Comments
 (0)