@@ -9,7 +9,7 @@ use optd_core::nodes::PlanNodeOrGroup;
9
9
use optd_core:: optimizer:: Optimizer ;
10
10
use optd_core:: rules:: { Rule , RuleMatcher } ;
11
11
12
- use super :: macros:: { define_impl_rule_discriminant , define_rule} ;
12
+ use super :: macros:: { define_impl_rule , define_rule} ;
13
13
use crate :: plan_nodes:: {
14
14
ArcDfPlanNode , BinOpPred , BinOpType , ColumnRefPred , ConstantPred , ConstantType , DfNodeType ,
15
15
DfPredType , DfReprPlanNode , DfReprPredNode , JoinType , ListPred , LogOpType ,
@@ -140,14 +140,241 @@ fn apply_join_assoc(
140
140
vec ! [ node. into_plan_node( ) . into( ) ]
141
141
}
142
142
143
- // Note: this matches all join types despite using `JoinType::Inner` below.
144
- define_impl_rule_discriminant ! (
145
- HashJoinRule ,
146
- apply_hash_join,
143
// Declares `HashJoinInnerRule`: the pattern `(Join(JoinType::Inner), left, right)` is paired
// with the implementation function `apply_hash_join_inner`.
// NOTE(review): presumably `define_impl_rule!` matches the join type exactly (inner joins
// only) -- confirm against the macro definition in `super::macros`.
+ define_impl_rule ! (
144
+ HashJoinInnerRule ,
145
+ apply_hash_join_inner,
147
146
( Join ( JoinType :: Inner ) , left, right)
148
147
) ;
149
148
150
- fn apply_hash_join (
149
+ fn apply_hash_join_inner (
150
+ optimizer : & impl Optimizer < DfNodeType > ,
151
+ binding : ArcDfPlanNode ,
152
+ ) -> Vec < PlanNodeOrGroup < DfNodeType > > {
153
+ let join = LogicalJoin :: from_plan_node ( binding) . unwrap ( ) ;
154
+ let cond = join. cond ( ) ;
155
+ let left = join. left ( ) ;
156
+ let right = join. right ( ) ;
157
+ let join_type = join. join_type ( ) ;
158
+ match cond. typ {
159
+ DfPredType :: BinOp ( BinOpType :: Eq ) => {
160
+ let left_schema = optimizer. get_schema_of ( left. clone ( ) ) ;
161
+ let op = BinOpPred :: from_pred_node ( cond. clone ( ) ) . unwrap ( ) ;
162
+ let left_expr = op. left_child ( ) ;
163
+ let right_expr = op. right_child ( ) ;
164
+ let Some ( mut left_expr) = ColumnRefPred :: from_pred_node ( left_expr) else {
165
+ return vec ! [ ] ;
166
+ } ;
167
+ let Some ( mut right_expr) = ColumnRefPred :: from_pred_node ( right_expr) else {
168
+ return vec ! [ ] ;
169
+ } ;
170
+ let can_convert = if left_expr. index ( ) < left_schema. len ( )
171
+ && right_expr. index ( ) >= left_schema. len ( )
172
+ {
173
+ true
174
+ } else if right_expr. index ( ) < left_schema. len ( )
175
+ && left_expr. index ( ) >= left_schema. len ( )
176
+ {
177
+ ( left_expr, right_expr) = ( right_expr, left_expr) ;
178
+ true
179
+ } else {
180
+ false
181
+ } ;
182
+
183
+ if can_convert {
184
+ let right_expr = ColumnRefPred :: new ( right_expr. index ( ) - left_schema. len ( ) ) ;
185
+ let node = PhysicalHashJoin :: new_unchecked (
186
+ left,
187
+ right,
188
+ ListPred :: new ( vec ! [ left_expr. into_pred_node( ) ] ) ,
189
+ ListPred :: new ( vec ! [ right_expr. into_pred_node( ) ] ) ,
190
+ * join_type,
191
+ ) ;
192
+ return vec ! [ node. into_plan_node( ) . into( ) ] ;
193
+ }
194
+ }
195
+ DfPredType :: LogOp ( LogOpType :: And ) => {
196
+ // currently only support consecutive equal queries
197
+ let mut is_consecutive_eq = true ;
198
+ for child in cond. children . clone ( ) {
199
+ if let DfPredType :: BinOp ( BinOpType :: Eq ) = child. typ {
200
+ continue ;
201
+ } else {
202
+ is_consecutive_eq = false ;
203
+ break ;
204
+ }
205
+ }
206
+ if !is_consecutive_eq {
207
+ return vec ! [ ] ;
208
+ }
209
+
210
+ let left_schema = optimizer. get_schema_of ( left. clone ( ) ) ;
211
+ let mut left_exprs = vec ! [ ] ;
212
+ let mut right_exprs = vec ! [ ] ;
213
+ for child in & cond. children {
214
+ let bin_op = BinOpPred :: from_pred_node ( child. clone ( ) ) . unwrap ( ) ;
215
+ let left_expr = bin_op. left_child ( ) ;
216
+ let right_expr = bin_op. right_child ( ) ;
217
+ let Some ( mut left_expr) = ColumnRefPred :: from_pred_node ( left_expr) else {
218
+ return vec ! [ ] ;
219
+ } ;
220
+ let Some ( mut right_expr) = ColumnRefPred :: from_pred_node ( right_expr) else {
221
+ return vec ! [ ] ;
222
+ } ;
223
+ let can_convert = if left_expr. index ( ) < left_schema. len ( )
224
+ && right_expr. index ( ) >= left_schema. len ( )
225
+ {
226
+ true
227
+ } else if right_expr. index ( ) < left_schema. len ( )
228
+ && left_expr. index ( ) >= left_schema. len ( )
229
+ {
230
+ ( left_expr, right_expr) = ( right_expr, left_expr) ;
231
+ true
232
+ } else {
233
+ false
234
+ } ;
235
+ if !can_convert {
236
+ return vec ! [ ] ;
237
+ }
238
+ let right_expr = ColumnRefPred :: new ( right_expr. index ( ) - left_schema. len ( ) ) ;
239
+ right_exprs. push ( right_expr. into_pred_node ( ) ) ;
240
+ left_exprs. push ( left_expr. into_pred_node ( ) ) ;
241
+ }
242
+
243
+ let node = PhysicalHashJoin :: new_unchecked (
244
+ left,
245
+ right,
246
+ ListPred :: new ( left_exprs) ,
247
+ ListPred :: new ( right_exprs) ,
248
+ * join_type,
249
+ ) ;
250
+ return vec ! [ node. into_plan_node( ) . into( ) ] ;
251
+ }
252
+ _ => { }
253
+ }
254
+ vec ! [ ]
255
+ }
256
+
257
// Declares `HashJoinLeftOuterRule`: the pattern `(Join(JoinType::LeftOuter), left, right)` is
// paired with the implementation function `apply_hash_join_left_outer`.
// NOTE(review): presumably `define_impl_rule!` matches the join type exactly (left outer
// joins only) -- confirm against the macro definition in `super::macros`.
+ define_impl_rule ! (
258
+ HashJoinLeftOuterRule ,
259
+ apply_hash_join_left_outer,
260
+ ( Join ( JoinType :: LeftOuter ) , left, right)
261
+ ) ;
262
+
263
+ fn apply_hash_join_left_outer (
264
+ optimizer : & impl Optimizer < DfNodeType > ,
265
+ binding : ArcDfPlanNode ,
266
+ ) -> Vec < PlanNodeOrGroup < DfNodeType > > {
267
+ let join = LogicalJoin :: from_plan_node ( binding) . unwrap ( ) ;
268
+ let cond = join. cond ( ) ;
269
+ let left = join. left ( ) ;
270
+ let right = join. right ( ) ;
271
+ let join_type = join. join_type ( ) ;
272
+ match cond. typ {
273
+ DfPredType :: BinOp ( BinOpType :: Eq ) => {
274
+ let left_schema = optimizer. get_schema_of ( left. clone ( ) ) ;
275
+ let op = BinOpPred :: from_pred_node ( cond. clone ( ) ) . unwrap ( ) ;
276
+ let left_expr = op. left_child ( ) ;
277
+ let right_expr = op. right_child ( ) ;
278
+ let Some ( mut left_expr) = ColumnRefPred :: from_pred_node ( left_expr) else {
279
+ return vec ! [ ] ;
280
+ } ;
281
+ let Some ( mut right_expr) = ColumnRefPred :: from_pred_node ( right_expr) else {
282
+ return vec ! [ ] ;
283
+ } ;
284
+ let can_convert = if left_expr. index ( ) < left_schema. len ( )
285
+ && right_expr. index ( ) >= left_schema. len ( )
286
+ {
287
+ true
288
+ } else if right_expr. index ( ) < left_schema. len ( )
289
+ && left_expr. index ( ) >= left_schema. len ( )
290
+ {
291
+ ( left_expr, right_expr) = ( right_expr, left_expr) ;
292
+ true
293
+ } else {
294
+ false
295
+ } ;
296
+
297
+ if can_convert {
298
+ let right_expr = ColumnRefPred :: new ( right_expr. index ( ) - left_schema. len ( ) ) ;
299
+ let node = PhysicalHashJoin :: new_unchecked (
300
+ left,
301
+ right,
302
+ ListPred :: new ( vec ! [ left_expr. into_pred_node( ) ] ) ,
303
+ ListPred :: new ( vec ! [ right_expr. into_pred_node( ) ] ) ,
304
+ * join_type,
305
+ ) ;
306
+ return vec ! [ node. into_plan_node( ) . into( ) ] ;
307
+ }
308
+ }
309
+ DfPredType :: LogOp ( LogOpType :: And ) => {
310
+ // currently only support consecutive equal queries
311
+ let mut is_consecutive_eq = true ;
312
+ for child in cond. children . clone ( ) {
313
+ if let DfPredType :: BinOp ( BinOpType :: Eq ) = child. typ {
314
+ continue ;
315
+ } else {
316
+ is_consecutive_eq = false ;
317
+ break ;
318
+ }
319
+ }
320
+ if !is_consecutive_eq {
321
+ return vec ! [ ] ;
322
+ }
323
+
324
+ let left_schema = optimizer. get_schema_of ( left. clone ( ) ) ;
325
+ let mut left_exprs = vec ! [ ] ;
326
+ let mut right_exprs = vec ! [ ] ;
327
+ for child in & cond. children {
328
+ let bin_op = BinOpPred :: from_pred_node ( child. clone ( ) ) . unwrap ( ) ;
329
+ let left_expr = bin_op. left_child ( ) ;
330
+ let right_expr = bin_op. right_child ( ) ;
331
+ let Some ( mut left_expr) = ColumnRefPred :: from_pred_node ( left_expr) else {
332
+ return vec ! [ ] ;
333
+ } ;
334
+ let Some ( mut right_expr) = ColumnRefPred :: from_pred_node ( right_expr) else {
335
+ return vec ! [ ] ;
336
+ } ;
337
+ let can_convert = if left_expr. index ( ) < left_schema. len ( )
338
+ && right_expr. index ( ) >= left_schema. len ( )
339
+ {
340
+ true
341
+ } else if right_expr. index ( ) < left_schema. len ( )
342
+ && left_expr. index ( ) >= left_schema. len ( )
343
+ {
344
+ ( left_expr, right_expr) = ( right_expr, left_expr) ;
345
+ true
346
+ } else {
347
+ false
348
+ } ;
349
+ if !can_convert {
350
+ return vec ! [ ] ;
351
+ }
352
+ let right_expr = ColumnRefPred :: new ( right_expr. index ( ) - left_schema. len ( ) ) ;
353
+ right_exprs. push ( right_expr. into_pred_node ( ) ) ;
354
+ left_exprs. push ( left_expr. into_pred_node ( ) ) ;
355
+ }
356
+
357
+ let node = PhysicalHashJoin :: new_unchecked (
358
+ left,
359
+ right,
360
+ ListPred :: new ( left_exprs) ,
361
+ ListPred :: new ( right_exprs) ,
362
+ * join_type,
363
+ ) ;
364
+ return vec ! [ node. into_plan_node( ) . into( ) ] ;
365
+ }
366
+ _ => { }
367
+ }
368
+ vec ! [ ]
369
+ }
370
+
371
// Declares `HashJoinLeftMarkRule`: the pattern `(Join(JoinType::LeftMark), left, right)` is
// paired with the implementation function `apply_hash_join_left_mark`.
// NOTE(review): presumably `define_impl_rule!` matches the join type exactly (left mark
// joins only) -- confirm against the macro definition in `super::macros`.
+ define_impl_rule ! (
372
+ HashJoinLeftMarkRule ,
373
+ apply_hash_join_left_mark,
374
+ ( Join ( JoinType :: LeftMark ) , left, right)
375
+ ) ;
376
+
377
+ fn apply_hash_join_left_mark (
151
378
optimizer : & impl Optimizer < DfNodeType > ,
152
379
binding : ArcDfPlanNode ,
153
380
) -> Vec < PlanNodeOrGroup < DfNodeType > > {
0 commit comments