Skip to content

Commit 7b7a071

Browse files
[FLINK-37973][table-planner] Change key extractor and recursive multi join logic to support any join tree
1 parent 2ae14e1 commit 7b7a071

File tree

16 files changed

+1850
-547
lines changed

16 files changed

+1850
-547
lines changed
Lines changed: 325 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,325 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to you under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.calcite.rel.rules;
18+
19+
import com.google.common.collect.ImmutableList;
20+
import com.google.common.collect.ImmutableMap;
21+
import com.google.common.collect.Lists;
22+
import org.apache.calcite.linq4j.Ord;
23+
import org.apache.calcite.plan.Convention;
24+
import org.apache.calcite.plan.RelOptCluster;
25+
import org.apache.calcite.plan.RelTraitSet;
26+
import org.apache.calcite.rel.AbstractRelNode;
27+
import org.apache.calcite.rel.RelNode;
28+
import org.apache.calcite.rel.RelWriter;
29+
import org.apache.calcite.rel.core.JoinRelType;
30+
import org.apache.calcite.rel.type.RelDataType;
31+
import org.apache.calcite.rex.RexNode;
32+
import org.apache.calcite.rex.RexShuttle;
33+
import org.apache.calcite.util.ImmutableBitSet;
34+
import org.apache.calcite.util.ImmutableIntList;
35+
import org.apache.calcite.util.ImmutableNullableList;
36+
import org.checkerframework.checker.nullness.qual.Nullable;
37+
38+
import java.util.ArrayList;
39+
import java.util.Collections;
40+
import java.util.HashMap;
41+
import java.util.List;
42+
import java.util.Map;
43+
44+
import static java.util.Objects.requireNonNull;
45+
46+
/**
47+
* A MultiJoin represents a join of N inputs, whereas regular Joins represent strictly binary joins.
48+
*/
49+
public final class MultiJoin extends AbstractRelNode {
50+
// ~ Instance fields --------------------------------------------------------
51+
52+
private final List<RelNode> inputs;
53+
private final RexNode joinFilter;
54+
55+
@SuppressWarnings("HidingField")
56+
private final RelDataType rowType;
57+
58+
private final boolean isFullOuterJoin;
59+
private final List<@Nullable RexNode> outerJoinConditions;
60+
private final ImmutableList<JoinRelType> joinTypes;
61+
private final List<@Nullable ImmutableBitSet> projFields;
62+
public final ImmutableMap<Integer, ImmutableIntList> joinFieldRefCountsMap;
63+
private final @Nullable RexNode postJoinFilter;
64+
private final List<Integer> levels;
65+
66+
// ~ Constructors -----------------------------------------------------------
67+
68+
/**
69+
* Constructs a MultiJoin.
70+
*
71+
* @param cluster cluster that join belongs to
72+
* @param inputs inputs into this multi-join
73+
* @param joinFilter join filter applicable to this join node
74+
* @param rowType row type of the join result of this node
75+
* @param isFullOuterJoin true if the join is a full outer join
76+
* @param outerJoinConditions outer join condition associated with each join input, if the input
77+
* is null-generating in a left or right outer join; null otherwise
78+
* @param joinTypes the join type corresponding to each input; if an input is null-generating in
79+
* a left or right outer join, the entry indicates the type of outer join; otherwise, the
80+
* entry is set to INNER
81+
* @param projFields fields that will be projected from each input; if null, projection
82+
* information is not available yet so it's assumed that all fields from the input are
83+
* projected
84+
* @param joinFieldRefCountsMap counters of the number of times each field is referenced in join
85+
* conditions, indexed by the input #
86+
* @param postJoinFilter filter to be applied after the joins are
87+
*/
88+
public MultiJoin(
89+
RelOptCluster cluster,
90+
List<RelNode> inputs,
91+
RexNode joinFilter,
92+
RelDataType rowType,
93+
boolean isFullOuterJoin,
94+
List<? extends @Nullable RexNode> outerJoinConditions,
95+
List<JoinRelType> joinTypes,
96+
List<? extends @Nullable ImmutableBitSet> projFields,
97+
ImmutableMap<Integer, ImmutableIntList> joinFieldRefCountsMap,
98+
@Nullable RexNode postJoinFilter) {
99+
this(
100+
cluster,
101+
inputs,
102+
joinFilter,
103+
rowType,
104+
isFullOuterJoin,
105+
outerJoinConditions,
106+
joinTypes,
107+
projFields,
108+
joinFieldRefCountsMap,
109+
postJoinFilter,
110+
Collections.emptyList());
111+
}
112+
113+
/**
114+
* Constructs a {@code MultiJoin}, which represents a flattened tree of joins for use by the
115+
* {@code JoinToMultiJoinRule}.
116+
*
117+
* @param cluster cluster that join belongs to
118+
* @param inputs inputs into this multi-join
119+
* @param joinFilter join filter applicable to this join node
120+
* @param rowType row type of the join result of this node
121+
* @param isFullOuterJoin true if the join is a full outer join
122+
* @param outerJoinConditions outer join condition associated with each join input, if the input
123+
* is null-generating in a left or right outer join; null otherwise
124+
* @param joinTypes the join type corresponding to each input; if an input is null-generating in
125+
* a left or right outer join, the entry indicates the type of outer join; otherwise, the
126+
* entry is set to INNER
127+
* @param projFields fields that will be projected from each input; if null, projection
128+
* information is not available yet so it's assumed that all fields from the input are
129+
* projected
130+
* @param joinFieldRefCountsMap counters of the number of times each field is referenced in join
131+
* conditions, indexed by the input #
132+
* @param postJoinFilter filter to be applied after the joins are
133+
* @param levels join tree levels
134+
*/
135+
public MultiJoin(
136+
RelOptCluster cluster,
137+
List<RelNode> inputs,
138+
RexNode joinFilter,
139+
RelDataType rowType,
140+
boolean isFullOuterJoin,
141+
List<? extends @Nullable RexNode> outerJoinConditions,
142+
List<JoinRelType> joinTypes,
143+
List<? extends @Nullable ImmutableBitSet> projFields,
144+
ImmutableMap<Integer, ImmutableIntList> joinFieldRefCountsMap,
145+
@Nullable RexNode postJoinFilter,
146+
List<Integer> levels) {
147+
super(cluster, cluster.traitSetOf(Convention.NONE));
148+
this.inputs = Lists.newArrayList(inputs);
149+
this.joinFilter = joinFilter;
150+
this.rowType = rowType;
151+
this.isFullOuterJoin = isFullOuterJoin;
152+
this.outerJoinConditions = ImmutableNullableList.copyOf(outerJoinConditions);
153+
this.levels = levels;
154+
assert outerJoinConditions.size() == inputs.size();
155+
this.joinTypes = ImmutableList.copyOf(joinTypes);
156+
this.projFields = ImmutableNullableList.copyOf(projFields);
157+
this.joinFieldRefCountsMap = joinFieldRefCountsMap;
158+
this.postJoinFilter = postJoinFilter;
159+
}
160+
161+
// ~ Methods ----------------------------------------------------------------
162+
163+
@Override
164+
public void replaceInput(int ordinalInParent, RelNode p) {
165+
inputs.set(ordinalInParent, p);
166+
recomputeDigest();
167+
}
168+
169+
@Override
170+
public RelNode copy(RelTraitSet traitSet, List<RelNode> inputs) {
171+
assert traitSet.containsIfApplicable(Convention.NONE);
172+
return new MultiJoin(
173+
getCluster(),
174+
inputs,
175+
joinFilter,
176+
rowType,
177+
isFullOuterJoin,
178+
outerJoinConditions,
179+
joinTypes,
180+
projFields,
181+
joinFieldRefCountsMap,
182+
postJoinFilter,
183+
levels);
184+
}
185+
186+
/** Returns a deep copy of {@link #joinFieldRefCountsMap}. */
187+
private Map<Integer, int[]> cloneJoinFieldRefCountsMap() {
188+
Map<Integer, int[]> clonedMap = new HashMap<>();
189+
for (int i = 0; i < inputs.size(); i++) {
190+
clonedMap.put(i, requireNonNull(joinFieldRefCountsMap.get(i)).toIntArray());
191+
}
192+
return clonedMap;
193+
}
194+
195+
@Override
196+
public RelWriter explainTerms(RelWriter pw) {
197+
List<String> joinTypeNames = new ArrayList<>();
198+
List<String> outerJoinConds = new ArrayList<>();
199+
List<String> projFieldObjects = new ArrayList<>();
200+
for (int i = 0; i < inputs.size(); i++) {
201+
joinTypeNames.add(joinTypes.get(i).name());
202+
RexNode outerJoinCondition = outerJoinConditions.get(i);
203+
if (outerJoinCondition == null) {
204+
outerJoinConds.add("NULL");
205+
} else {
206+
outerJoinConds.add(outerJoinCondition.toString());
207+
}
208+
ImmutableBitSet projField = projFields.get(i);
209+
if (projField == null) {
210+
projFieldObjects.add("ALL");
211+
} else {
212+
projFieldObjects.add(projField.toString());
213+
}
214+
}
215+
216+
super.explainTerms(pw);
217+
for (Ord<RelNode> ord : Ord.zip(inputs)) {
218+
pw.input("input#" + ord.i, ord.e);
219+
}
220+
return pw.item("joinFilter", joinFilter)
221+
.item("isFullOuterJoin", isFullOuterJoin)
222+
.item("joinTypes", joinTypeNames)
223+
.item("outerJoinConditions", outerJoinConds)
224+
.item("projFields", projFieldObjects)
225+
.itemIf("postJoinFilter", postJoinFilter, postJoinFilter != null);
226+
}
227+
228+
@Override
229+
public RelDataType deriveRowType() {
230+
return rowType;
231+
}
232+
233+
@Override
234+
public List<RelNode> getInputs() {
235+
return inputs;
236+
}
237+
238+
@Override
239+
public RelNode accept(RexShuttle shuttle) {
240+
RexNode joinFilter = shuttle.apply(this.joinFilter);
241+
List<@Nullable RexNode> outerJoinConditions = shuttle.apply(this.outerJoinConditions);
242+
RexNode postJoinFilter = shuttle.apply(this.postJoinFilter);
243+
244+
if (joinFilter == this.joinFilter
245+
&& outerJoinConditions == this.outerJoinConditions
246+
&& postJoinFilter == this.postJoinFilter) {
247+
return this;
248+
}
249+
250+
return new MultiJoin(
251+
getCluster(),
252+
inputs,
253+
joinFilter,
254+
rowType,
255+
isFullOuterJoin,
256+
outerJoinConditions,
257+
joinTypes,
258+
projFields,
259+
joinFieldRefCountsMap,
260+
postJoinFilter,
261+
levels);
262+
}
263+
264+
/** Returns join filters associated with this MultiJoin. */
265+
public RexNode getJoinFilter() {
266+
return joinFilter;
267+
}
268+
269+
/** Returns true if the MultiJoin corresponds to a full outer join. */
270+
public boolean isFullOuterJoin() {
271+
return isFullOuterJoin;
272+
}
273+
274+
/** Returns outer join conditions for null-generating inputs. */
275+
public List<@Nullable RexNode> getOuterJoinConditions() {
276+
return outerJoinConditions;
277+
}
278+
279+
/** Returns join types of each input. */
280+
public List<JoinRelType> getJoinTypes() {
281+
return joinTypes;
282+
}
283+
284+
/**
285+
* Returns bitmaps representing the fields projected from each input; if an entry is null, all
286+
* fields are projected.
287+
*/
288+
public List<@Nullable ImmutableBitSet> getProjFields() {
289+
return projFields;
290+
}
291+
292+
/**
293+
* Returns the map of reference counts for each input, representing the fields accessed in join
294+
* conditions.
295+
*/
296+
public ImmutableMap<Integer, ImmutableIntList> getJoinFieldRefCountsMap() {
297+
return joinFieldRefCountsMap;
298+
}
299+
300+
/**
301+
* Returns a copy of the map of reference counts for each input, representing the fields
302+
* accessed in join conditions.
303+
*/
304+
public Map<Integer, int[]> getCopyJoinFieldRefCountsMap() {
305+
return cloneJoinFieldRefCountsMap();
306+
}
307+
308+
/** Returns post-join filter associated with this MultiJoin. */
309+
public @Nullable RexNode getPostJoinFilter() {
310+
return postJoinFilter;
311+
}
312+
313+
boolean containsOuter() {
314+
for (JoinRelType joinType : joinTypes) {
315+
if (joinType.isOuterJoin()) {
316+
return true;
317+
}
318+
}
319+
return false;
320+
}
321+
322+
public List<Integer> getLevels() {
323+
return levels;
324+
}
325+
}

flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/calcite/RelTimeIndicatorConverter.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,7 @@ private RelNode visitMultiJoin(FlinkLogicalMultiJoin multiJoin) {
556556
newJoinConditions,
557557
multiJoin.getJoinTypes(),
558558
newPostJoinFilter,
559+
multiJoin.getLevels(),
559560
multiJoin.getHints());
560561
}
561562

0 commit comments

Comments
 (0)