Skip to content

Commit 302ad4c

Browse files
committed
introduce lateInteractionScore
1 parent f1825fd commit 302ad4c

File tree

9 files changed

+624
-1
lines changed

9 files changed

+624
-1
lines changed

modules/lang-painless/src/main/java/org/opensearch/painless/PainlessModulePlugin.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,13 @@ public final class PainlessModulePlugin extends Plugin implements ScriptPlugin,
103103
// Functions used for scoring docs
104104
List<Allowlist> scoreFn = new ArrayList<>(Allowlist.BASE_ALLOWLISTS);
105105
scoreFn.add(AllowlistLoader.loadFromResourceFiles(Allowlist.class, "org.opensearch.score.txt"));
106+
107+
Allowlist vectorFunctions = AllowlistLoader.loadFromResourceFiles(
108+
Allowlist.class,
109+
"org.opensearch.painless.vector_functions.txt"
110+
);
111+
scoreFn.add(vectorFunctions);
112+
106113
map.put(ScoreScript.CONTEXT, scoreFn);
107114

108115
// Functions available to ingest pipelines
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*
4+
* The OpenSearch Contributors require contributions made to
5+
* this file be licensed under the Apache-2.0 license or a
6+
* compatible open source license.
7+
*/
8+
9+
package org.opensearch.painless.functions;
10+
11+
import java.util.List;
12+
import java.util.Map;
13+
14+
/**
 * Static vector-scoring functions. Not instantiable.
 */
public final class PainlessVectorFunctions {

    private PainlessVectorFunctions() {}

    /**
     * Calculates the late interaction score between query vectors and document vectors.
     * For each query vector, finds the maximum dot product with any document vector of the
     * same dimension and sums these maxima (ColBERT-style MaxSim).
     *
     * <p>Behavior on irregular input:
     * <ul>
     *   <li>{@code null}/empty {@code queryVectors}, a {@code null} {@code doc}, or a missing/empty
     *       document field yields {@code 0.0}.</li>
     *   <li>{@code null} or empty individual vectors are skipped.</li>
     *   <li>Document vectors whose dimension differs from the query vector are skipped; a query
     *       vector with no comparable document vector contributes {@code 0.0}.</li>
     *   <li>Vector components are read as {@link Number}, so integer or float values are accepted
     *       in addition to doubles.</li>
     * </ul>
     *
     * @param queryVectors List of query vectors, each a list of numeric components
     * @param docFieldName Name of the field in the document containing the list of vectors
     * @param doc Document source as a map
     * @return Sum over query vectors of the maximum dot product with a document vector
     * @throws ClassCastException if the document field is not a list of lists of numbers
     */
    @SuppressWarnings("unchecked")
    public static double lateInteractionScore(List<List<Double>> queryVectors, String docFieldName, Map<String, Object> doc) {
        if (queryVectors == null || queryVectors.isEmpty() || doc == null) {
            return 0.0;
        }

        // Components are typed as Number rather than Double: generic parameters are erased at
        // runtime, and the lists may actually hold Integer/Float/Long values.
        List<List<? extends Number>> docVectors = (List<List<? extends Number>>) doc.get(docFieldName);
        if (docVectors == null || docVectors.isEmpty()) {
            return 0.0;
        }

        double totalMaxSim = 0.0;
        for (List<? extends Number> queryVector : queryVectors) {
            if (queryVector == null || queryVector.isEmpty()) {
                continue;
            }

            // Start below every finite value so that an all-negative set of similarities
            // still yields its true maximum instead of being clamped to zero.
            double maxDocTokenSim = Double.NEGATIVE_INFINITY;
            for (List<? extends Number> docVector : docVectors) {
                // The size check also rejects empty document vectors, since queryVector is non-empty.
                if (docVector == null || docVector.size() != queryVector.size()) {
                    continue;
                }
                double currentSim = dotProduct(queryVector, docVector);
                if (currentSim > maxDocTokenSim) {
                    maxDocTokenSim = currentSim;
                }
            }

            // Still at the sentinel means no comparable document vector was found.
            if (maxDocTokenSim != Double.NEGATIVE_INFINITY) {
                totalMaxSim += maxDocTokenSim;
            }
        }
        return totalMaxSim;
    }

    /** Returns the dot product of two vectors of equal length. */
    private static double dotProduct(List<? extends Number> left, List<? extends Number> right) {
        double sum = 0.0;
        for (int k = 0; k < left.size(); k++) {
            sum += left.get(k).doubleValue() * right.get(k).doubleValue();
        }
        return sum;
    }
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
#
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# The OpenSearch Contributors require contributions made to
5+
# this file be licensed under the Apache-2.0 license or a
6+
# compatible open source license.
7+
#
8+
9+
# This file contains an allowlist for vector functions to be used in Score context
10+
11+
class org.opensearch.painless.functions.PainlessVectorFunctions @no_import {
12+
double lateInteractionScore(List, String, Map)
13+
}

modules/lang-painless/src/main/resources/org/opensearch/painless/spi/org.opensearch.score.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,5 +39,5 @@ static_import {
3939
double decayDateLinear(String, String, String, double, ZonedDateTime) bound_to org.opensearch.script.ScoreScriptUtils$DecayDateLinear
4040
double decayDateExp(String, String, String, double, ZonedDateTime) bound_to org.opensearch.script.ScoreScriptUtils$DecayDateExp
4141
double decayDateGauss(String, String, String, double, ZonedDateTime) bound_to org.opensearch.script.ScoreScriptUtils$DecayDateGauss
42+
double lateInteractionScore(List, String, Map) from_class org.opensearch.painless.functions.PainlessVectorFunctions
4243
}
43-
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
#
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# The OpenSearch Contributors require contributions made to
5+
# this file be licensed under the Apache-2.0 license or a
6+
# compatible open source license.
7+
#
8+
9+
# This file contains an allowlist for vector functions to be used in Score context
10+
11+
class org.opensearch.painless.functions.PainlessVectorFunctions @no_import {
12+
double lateInteractionScore(List, String, Map)
13+
}
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*
4+
* The OpenSearch Contributors require contributions made to
5+
* this file be licensed under the Apache-2.0 license or a
6+
* compatible open source license.
7+
*/
8+
9+
package org.opensearch.painless;
10+
11+
import org.opensearch.common.settings.Settings;
12+
import org.opensearch.painless.action.PainlessExecuteAction;
13+
import org.opensearch.painless.spi.Allowlist;
14+
import org.opensearch.painless.spi.AllowlistLoader;
15+
import org.opensearch.script.ScoreScript;
16+
import org.opensearch.script.ScriptContext;
17+
import org.opensearch.test.OpenSearchTestCase;
18+
19+
import java.util.ArrayList;
20+
import java.util.Collections;
21+
import java.util.HashMap;
22+
import java.util.List;
23+
import java.util.Map;
24+
25+
public class LateInteractionScoreScriptTests extends ScriptTestCase {
26+
27+
private static final PainlessScriptEngine SCORE_SCRIPT_ENGINE;
28+
29+
static {
30+
Map<ScriptContext<?>, List<Allowlist>> contexts = new HashMap<>();
31+
List<Allowlist> allowlists = new ArrayList<>(Allowlist.BASE_ALLOWLISTS);
32+
allowlists.add(AllowlistLoader.loadFromResourceFiles(Allowlist.class, "org.opensearch.score.txt"));
33+
allowlists.add(AllowlistLoader.loadFromResourceFiles(Allowlist.class, "org.opensearch.painless.vector_functions.txt"));
34+
contexts.put(ScoreScript.CONTEXT, allowlists);
35+
contexts.put(PainlessExecuteAction.PainlessTestScript.CONTEXT, allowlists);
36+
SCORE_SCRIPT_ENGINE = new PainlessScriptEngine(Settings.EMPTY, contexts);
37+
}
38+
39+
@Override
40+
protected PainlessScriptEngine getEngine() {
41+
return SCORE_SCRIPT_ENGINE;
42+
}
43+
44+
public void testLateInteractionScoreInScript() {
45+
String script = "lateInteractionScore(params.query_vector, 'my_vector', params._source)";
46+
47+
// Create query vectors
48+
List<List<Double>> queryVectors = new ArrayList<>();
49+
List<Double> qv1 = new ArrayList<>();
50+
qv1.add(0.1);
51+
qv1.add(0.2);
52+
queryVectors.add(qv1);
53+
54+
// Create document vectors
55+
List<List<Double>> docVectors = new ArrayList<>();
56+
List<Double> dv1 = new ArrayList<>();
57+
dv1.add(0.3);
58+
dv1.add(0.4);
59+
docVectors.add(dv1);
60+
61+
// Create document source
62+
Map<String, Object> doc = new HashMap<>();
63+
doc.put("my_vector", docVectors);
64+
65+
// Create parameters
66+
Map<String, Object> params = new HashMap<>();
67+
params.put("query_vector", queryVectors);
68+
params.put("_source", doc);
69+
70+
// Expected result: 0.1*0.3 + 0.2*0.4 = 0.11
71+
double expected = 0.11;
72+
73+
// Use PainlessTestScript context instead of ScoreScript for testing
74+
PainlessExecuteAction.PainlessTestScript.Factory factory = getEngine().compile(
75+
"test",
76+
script,
77+
PainlessExecuteAction.PainlessTestScript.CONTEXT,
78+
Collections.emptyMap()
79+
);
80+
81+
PainlessExecuteAction.PainlessTestScript testScript = factory.newInstance(params);
82+
double result = ((Number) testScript.execute()).doubleValue();
83+
84+
assertEquals(expected, result, 0.001);
85+
}
86+
87+
public void testLateInteractionScoreWithInlineScript() {
88+
String script = """
89+
if (params.query_vector == null) {
90+
return 0.0;
91+
}
92+
93+
double totalMaxSim = 0.0;
94+
def queryVectors = params.query_vector;
95+
96+
for (int i = 0; i < queryVectors.length; i++) {
97+
def q_vec = queryVectors[i];
98+
99+
double maxDocTokenSim = 0.0;
100+
101+
if (params._source.my_vector == null || params._source.my_vector.length == 0) {
102+
continue;
103+
}
104+
105+
for (int j = 0; j < params._source.my_vector.length; j++) {
106+
def doc_token_vec = params._source.my_vector[j];
107+
108+
double currentSim = 0.0;
109+
if (q_vec.length == doc_token_vec.length) {
110+
for (int k = 0; k < q_vec.length; k++) {
111+
currentSim += q_vec[k] * doc_token_vec[k];
112+
}
113+
} else {
114+
currentSim = 0.0;
115+
}
116+
117+
if (currentSim > maxDocTokenSim) {
118+
maxDocTokenSim = currentSim;
119+
}
120+
}
121+
totalMaxSim += maxDocTokenSim;
122+
}
123+
return totalMaxSim;
124+
""";
125+
126+
// Create query vectors
127+
List<List<Double>> queryVectors = new ArrayList<>();
128+
List<Double> qv1 = new ArrayList<>();
129+
qv1.add(0.1);
130+
qv1.add(0.2);
131+
queryVectors.add(qv1);
132+
133+
// Create document vectors
134+
List<List<Double>> docVectors = new ArrayList<>();
135+
List<Double> dv1 = new ArrayList<>();
136+
dv1.add(0.3);
137+
dv1.add(0.4);
138+
docVectors.add(dv1);
139+
140+
// Create document source
141+
Map<String, Object> doc = new HashMap<>();
142+
doc.put("my_vector", docVectors);
143+
144+
// Create parameters
145+
Map<String, Object> params = new HashMap<>();
146+
params.put("query_vector", queryVectors);
147+
params.put("_source", doc);
148+
149+
// Expected result: 0.1*0.3 + 0.2*0.4 = 0.11
150+
double expected = 0.11;
151+
152+
// Use PainlessTestScript context instead of ScoreScript for testing
153+
PainlessExecuteAction.PainlessTestScript.Factory factory = getEngine().compile(
154+
"test",
155+
script,
156+
PainlessExecuteAction.PainlessTestScript.CONTEXT,
157+
Collections.emptyMap()
158+
);
159+
160+
PainlessExecuteAction.PainlessTestScript testScript = factory.newInstance(params);
161+
double result = ((Number) testScript.execute()).doubleValue();
162+
163+
assertEquals(expected, result, 0.001);
164+
}
165+
}

0 commit comments

Comments
 (0)