Skip to content

Commit 0e4e6d7

Browse files
authored
[FLINK-38425][table] Add rule to convert correlate node to vector search physical node (#27121)
1 parent 11a68b6 commit 0e4e6d7

File tree

9 files changed

+950
-30
lines changed

9 files changed

+950
-30
lines changed

flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/ml/SqlVectorSearchTableFunction.java

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,19 +195,28 @@ public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFail
195195
throwOnFailure);
196196
}
197197

198-
// check topK is literal
198+
// check top_k is a positive integer literal
199199
LogicalType topKType = toLogicalType(callBinding.getOperandType(3));
200200
if (!operands.get(3).getKind().equals(SqlKind.LITERAL)
201201
|| !topKType.is(LogicalTypeRoot.INTEGER)) {
202202
return SqlValidatorUtils.throwExceptionOrReturnFalse(
203203
Optional.of(
204204
new ValidationException(
205205
String.format(
206-
"Expect parameter topK is integer literal in VECTOR_SEARCH, but it is %s with type %s.",
206+
"Expect parameter top_k is an INTEGER NOT NULL literal in VECTOR_SEARCH, but it is %s with type %s.",
207207
operands.get(3), topKType))),
208208
throwOnFailure);
209209
}
210-
210+
Integer topK = callBinding.getOperandLiteralValue(3, Integer.class);
211+
if (topK == null || topK <= 0) {
212+
return SqlValidatorUtils.throwExceptionOrReturnFalse(
213+
Optional.of(
214+
new ValidationException(
215+
String.format(
216+
"Parameter top_k must be greater than 0, but was %s.",
217+
topK))),
218+
throwOnFailure);
219+
}
211220
return true;
212221
}
213222

@@ -218,7 +227,8 @@ public SqlOperandCountRange getOperandCountRange() {
218227

219228
@Override
220229
public String getAllowedSignatures(SqlOperator op, String opName) {
221-
return opName + "(TABLE table_name, DESCRIPTOR(query_column), search_column, top_k)";
230+
return opName
231+
+ "(TABLE search_table, DESCRIPTOR(column_to_search), column_to_query, top_k)";
222232
}
223233

224234
@Override
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.flink.table.planner.plan.nodes.exec.spec;
20+
21+
import org.apache.flink.table.planner.plan.utils.FunctionCallUtil.FunctionParam;
22+
23+
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonCreator;
24+
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonIgnore;
25+
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty;
26+
27+
import org.apache.calcite.rel.core.JoinRelType;
28+
29+
import java.util.Map;
30+
31+
/** VectorSearchSpec describes how vector search is performed. */
32+
public class VectorSearchSpec {
33+
34+
public static final String FIELD_NAME_JOIN_TYPE = "joinType";
35+
public static final String FIELD_NAME_SEARCH_COLUMNS = "searchColumns";
36+
public static final String FIELD_NAME_TOP_K = "topK";
37+
38+
@JsonProperty(FIELD_NAME_JOIN_TYPE)
39+
private final JoinRelType joinRelType;
40+
41+
/** KV: column_to_search -> column_to_query. */
42+
@JsonProperty(FIELD_NAME_SEARCH_COLUMNS)
43+
private final Map<Integer, FunctionParam> searchColumns;
44+
45+
@JsonProperty(FIELD_NAME_TOP_K)
46+
private final FunctionParam topK;
47+
48+
@JsonCreator
49+
public VectorSearchSpec(
50+
@JsonProperty(FIELD_NAME_JOIN_TYPE) JoinRelType joinRelType,
51+
@JsonProperty(FIELD_NAME_SEARCH_COLUMNS) Map<Integer, FunctionParam> searchColumns,
52+
@JsonProperty(FIELD_NAME_TOP_K) FunctionParam topK) {
53+
this.joinRelType = joinRelType;
54+
this.searchColumns = searchColumns;
55+
this.topK = topK;
56+
}
57+
58+
@JsonIgnore
59+
public JoinRelType getJoinType() {
60+
return joinRelType;
61+
}
62+
63+
@JsonIgnore
64+
public Map<Integer, FunctionParam> getSearchColumns() {
65+
return searchColumns;
66+
}
67+
68+
@JsonIgnore
69+
public FunctionParam getTopK() {
70+
return topK;
71+
}
72+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.flink.table.planner.plan.nodes.physical.stream;
20+
21+
import org.apache.flink.table.planner.plan.nodes.exec.ExecNode;
22+
import org.apache.flink.table.planner.plan.nodes.exec.spec.VectorSearchSpec;
23+
import org.apache.flink.table.planner.plan.schema.TableSourceTable;
24+
import org.apache.flink.table.planner.plan.utils.FunctionCallUtil;
25+
import org.apache.flink.table.planner.plan.utils.JoinTypeUtil;
26+
import org.apache.flink.table.planner.plan.utils.RelExplainUtil;
27+
28+
import org.apache.calcite.plan.RelOptCluster;
29+
import org.apache.calcite.plan.RelOptTable;
30+
import org.apache.calcite.plan.RelTraitSet;
31+
import org.apache.calcite.rel.RelNode;
32+
import org.apache.calcite.rel.RelWriter;
33+
import org.apache.calcite.rel.SingleRel;
34+
import org.apache.calcite.rel.type.RelDataType;
35+
import org.apache.calcite.rex.RexProgram;
36+
37+
import javax.annotation.Nullable;
38+
39+
import java.util.List;
40+
import java.util.stream.Collectors;
41+
42+
/** Stream physical RelNode for vector search table function. */
43+
public class StreamPhysicalVectorSearchTableFunction extends SingleRel
44+
implements StreamPhysicalRel {
45+
46+
private final RelOptTable searchTable;
47+
private final @Nullable RexProgram calcProgram;
48+
private final VectorSearchSpec vectorSearchSpec;
49+
private final RelDataType outputRowType;
50+
51+
public StreamPhysicalVectorSearchTableFunction(
52+
RelOptCluster cluster,
53+
RelTraitSet traits,
54+
RelNode input,
55+
RelOptTable searchTable,
56+
@Nullable RexProgram calcProgram,
57+
VectorSearchSpec vectorSearchSpec,
58+
RelDataType outputRowType) {
59+
super(cluster, traits, input);
60+
this.searchTable = searchTable;
61+
this.calcProgram = calcProgram;
62+
this.vectorSearchSpec = vectorSearchSpec;
63+
this.outputRowType = outputRowType;
64+
}
65+
66+
@Override
67+
public RelNode copy(RelTraitSet traitSet, List<RelNode> inputs) {
68+
return new StreamPhysicalVectorSearchTableFunction(
69+
getCluster(),
70+
traitSet,
71+
inputs.get(0),
72+
searchTable,
73+
calcProgram,
74+
vectorSearchSpec,
75+
outputRowType);
76+
}
77+
78+
@Override
79+
protected RelDataType deriveRowType() {
80+
return outputRowType;
81+
}
82+
83+
@Override
84+
public RelWriter explainTerms(RelWriter pw) {
85+
List<String> columnToSearch =
86+
vectorSearchSpec.getSearchColumns().keySet().stream()
87+
.map(
88+
calcProgram == null
89+
? searchTable.getRowType().getFieldNames()::get
90+
: calcProgram.getOutputRowType().getFieldNames()::get)
91+
.collect(Collectors.toList());
92+
List<String> columnToQuery =
93+
vectorSearchSpec.getSearchColumns().values().stream()
94+
.map(this::explainQueryColumnParam)
95+
.collect(Collectors.toList());
96+
97+
Integer topK =
98+
((FunctionCallUtil.Constant) vectorSearchSpec.getTopK())
99+
.literal.getValueAs(Integer.class);
100+
101+
String leftSelect = String.join(", ", getInput(0).getRowType().getFieldNames());
102+
String rightSelect =
103+
calcProgram == null
104+
? String.join(", ", searchTable.getRowType().getFieldNames())
105+
: RelExplainUtil.selectionToString(
106+
calcProgram,
107+
this::getExpressionString,
108+
RelExplainUtil.preferExpressionFormat(pw),
109+
convertToExpressionDetail(pw.getDetailLevel()));
110+
111+
return super.explainTerms(pw)
112+
.item(
113+
"table",
114+
((TableSourceTable) searchTable)
115+
.contextResolvedTable()
116+
.getIdentifier()
117+
.asSummaryString())
118+
.item("joinType", JoinTypeUtil.getFlinkJoinType(vectorSearchSpec.getJoinType()))
119+
.item("columnToSearch", String.join(", ", columnToSearch))
120+
.item("columnToQuery", String.join(", ", columnToQuery))
121+
.item("topK", topK)
122+
.item("select", String.join(", ", leftSelect, rightSelect, "score"));
123+
}
124+
125+
@Override
126+
public boolean requireWatermark() {
127+
return false;
128+
}
129+
130+
@Override
131+
public ExecNode<?> translateToExecNode() {
132+
throw new UnsupportedOperationException("Vector search not supported yet.");
133+
}
134+
135+
private String explainQueryColumnParam(FunctionCallUtil.FunctionParam param) {
136+
if (param instanceof FunctionCallUtil.FieldRef) {
137+
int index = ((FunctionCallUtil.FieldRef) param).index;
138+
return getInput(0).getRowType().getFieldNames().get(index);
139+
} else if (param instanceof FunctionCallUtil.Constant) {
140+
return ((FunctionCallUtil.Constant) param).literal.toString();
141+
}
142+
return null;
143+
}
144+
}

0 commit comments

Comments
 (0)