Skip to content

Commit 07d44de

Browse files
committed
move Full Precision vector sim values source to a separate class
1 parent da62c29 commit 07d44de

File tree

4 files changed

+172
-249
lines changed

4 files changed

+172
-249
lines changed

lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityValuesSource.java

Lines changed: 1 addition & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -21,54 +21,18 @@
2121
import java.util.Arrays;
2222
import java.util.Objects;
2323
import org.apache.lucene.index.ByteVectorValues;
24-
import org.apache.lucene.index.FieldInfo;
25-
import org.apache.lucene.index.KnnVectorValues;
2624
import org.apache.lucene.index.LeafReaderContext;
27-
import org.apache.lucene.index.VectorSimilarityFunction;
2825

2926
/**
3027
* A {@link DoubleValuesSource} which computes the vector similarity scores between the query vector
3128
* and the {@link org.apache.lucene.document.KnnByteVectorField} for documents.
3229
*/
3330
class ByteVectorSimilarityValuesSource extends VectorSimilarityValuesSource {
34-
35-
/**
36-
* Creates a {@link ByteVectorSimilarityValuesSource} that scores on full precision vector values
37-
*/
38-
public static DoubleValues fullPrecisionScores(
39-
LeafReaderContext ctx, byte[] queryVector, String vectorField) throws IOException {
40-
return new ByteVectorSimilarityValuesSource(queryVector, vectorField, true)
41-
.getValues(ctx, null);
42-
}
43-
4431
private final byte[] queryVector;
45-
private final boolean useFullPrecision;
4632

47-
/**
48-
* Creates a {@link DoubleValuesSource} that returns vector similarity score between provided
49-
* query vector and field for documents. Uses the scorer exposed by configured vectors reader.
50-
*
51-
* @param vector the query vector
52-
* @param fieldName the field name of the {@link org.apache.lucene.document.KnnByteVectorField}
53-
*/
5433
public ByteVectorSimilarityValuesSource(byte[] vector, String fieldName) {
55-
this(vector, fieldName, false);
56-
}
57-
58-
/**
59-
* Creates a {@link DoubleValuesSource} that returns vector similarity score between provided
60-
* query vector and field for documents.
61-
*
62-
* @param vector the query vector
63-
* @param fieldName the field name of the {@link org.apache.lucene.document.KnnByteVectorField}
64-
* @param useFullPrecision uses full precision raw vectors for similarity computation if true,
65-
* otherwise the configured vectors reader is used, which may be quantized or full precision.
66-
*/
67-
public ByteVectorSimilarityValuesSource(
68-
byte[] vector, String fieldName, boolean useFullPrecision) {
6934
super(fieldName);
7035
this.queryVector = vector;
71-
this.useFullPrecision = useFullPrecision;
7236
}
7337

7438
@Override
@@ -78,35 +42,7 @@ public VectorScorer getScorer(LeafReaderContext ctx) throws IOException {
7842
ByteVectorValues.checkField(ctx.reader(), fieldName);
7943
return null;
8044
}
81-
final FieldInfo fi = ctx.reader().getFieldInfos().fieldInfo(fieldName);
82-
if (fi.getVectorDimension() != queryVector.length) {
83-
throw new IllegalArgumentException(
84-
"Query vector dimension does not match field dimension: "
85-
+ queryVector.length
86-
+ " != "
87-
+ fi.getVectorDimension());
88-
}
89-
90-
// default vector scorer
91-
if (useFullPrecision == false) {
92-
return vectorValues.scorer(queryVector);
93-
}
94-
95-
final VectorSimilarityFunction vectorSimilarityFunction = fi.getVectorSimilarityFunction();
96-
return new VectorScorer() {
97-
final KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
98-
99-
@Override
100-
public float score() throws IOException {
101-
return vectorSimilarityFunction.compare(
102-
queryVector, vectorValues.vectorValue(iterator.index()));
103-
}
104-
105-
@Override
106-
public DocIdSetIterator iterator() {
107-
return iterator;
108-
}
109-
};
45+
return vectorValues.scorer(queryVector);
11046
}
11147

11248
@Override

lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityValuesSource.java

Lines changed: 1 addition & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -20,55 +20,20 @@
2020
import java.io.IOException;
2121
import java.util.Arrays;
2222
import java.util.Objects;
23-
import org.apache.lucene.index.FieldInfo;
2423
import org.apache.lucene.index.FloatVectorValues;
25-
import org.apache.lucene.index.KnnVectorValues;
2624
import org.apache.lucene.index.LeafReaderContext;
27-
import org.apache.lucene.index.VectorSimilarityFunction;
2825

2926
/**
3027
* A {@link DoubleValuesSource} which computes the vector similarity scores between the query vector
3128
* and the {@link org.apache.lucene.document.KnnFloatVectorField} for documents.
3229
*/
3330
class FloatVectorSimilarityValuesSource extends VectorSimilarityValuesSource {
3431

35-
/**
36-
* Creates a {@link FloatVectorSimilarityValuesSource} that scores on full precision vector values
37-
*/
38-
public static DoubleValues fullPrecisionScores(
39-
LeafReaderContext ctx, float[] queryVector, String vectorField) throws IOException {
40-
return new FloatVectorSimilarityValuesSource(queryVector, vectorField, true)
41-
.getValues(ctx, null);
42-
}
43-
4432
private final float[] queryVector;
45-
private final boolean useFullPrecision;
4633

47-
/**
48-
* Creates a {@link DoubleValuesSource} that returns vector similarity score between provided
49-
* query vector and field for documents. Uses the scorer exposed by configured vectors reader.
50-
*
51-
* @param vector the query vector
52-
* @param fieldName the field name of the {@link org.apache.lucene.document.KnnFloatVectorField}
53-
*/
5434
public FloatVectorSimilarityValuesSource(float[] vector, String fieldName) {
55-
this(vector, fieldName, false);
56-
}
57-
58-
/**
59-
* Creates a {@link DoubleValuesSource} that returns vector similarity score between provided
60-
* query vector and field for documents.
61-
*
62-
* @param vector the query vector
63-
* @param fieldName the field name of the {@link org.apache.lucene.document.KnnFloatVectorField}
64-
* @param useFullPrecision uses full precision raw vectors for similarity computation if true,
65-
* otherwise the configured vectors reader is used, which may be quantized or full precision.
66-
*/
67-
public FloatVectorSimilarityValuesSource(
68-
float[] vector, String fieldName, boolean useFullPrecision) {
6935
super(fieldName);
7036
this.queryVector = vector;
71-
this.useFullPrecision = useFullPrecision;
7237
}
7338

7439
@Override
@@ -78,35 +43,7 @@ public VectorScorer getScorer(LeafReaderContext ctx) throws IOException {
7843
FloatVectorValues.checkField(ctx.reader(), fieldName);
7944
return null;
8045
}
81-
final FieldInfo fi = ctx.reader().getFieldInfos().fieldInfo(fieldName);
82-
if (fi.getVectorDimension() != queryVector.length) {
83-
throw new IllegalArgumentException(
84-
"Query vector dimension does not match field dimension: "
85-
+ queryVector.length
86-
+ " != "
87-
+ fi.getVectorDimension());
88-
}
89-
90-
// default vector scorer
91-
if (useFullPrecision == false) {
92-
return vectorValues.scorer(queryVector);
93-
}
94-
95-
final VectorSimilarityFunction vectorSimilarityFunction = fi.getVectorSimilarityFunction();
96-
return new VectorScorer() {
97-
final KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
98-
99-
@Override
100-
public float score() throws IOException {
101-
return vectorSimilarityFunction.compare(
102-
queryVector, vectorValues.vectorValue(iterator.index()));
103-
}
104-
105-
@Override
106-
public DocIdSetIterator iterator() {
107-
return iterator;
108-
}
109-
};
46+
return vectorValues.scorer(queryVector);
11047
}
11148

11249
@Override
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.lucene.search;
19+
20+
import java.io.IOException;
21+
import java.util.Arrays;
22+
import java.util.Objects;
23+
import org.apache.lucene.index.FieldInfo;
24+
import org.apache.lucene.index.FloatVectorValues;
25+
import org.apache.lucene.index.KnnVectorValues;
26+
import org.apache.lucene.index.LeafReaderContext;
27+
import org.apache.lucene.index.VectorSimilarityFunction;
28+
29+
/**
30+
* A {@link DoubleValuesSource} that computes vector similarity between a query vector and raw full
31+
* precision vectors indexed in provided {@link org.apache.lucene.document.KnnFloatVectorField} in
32+
* documents.
33+
*/
34+
public class FullPrecisionFloatVectorSimilarityValuesSource extends DoubleValuesSource {
35+
36+
private final float[] queryVector;
37+
private final String fieldName;
38+
private VectorSimilarityFunction vectorSimilarityFunction;
39+
40+
/**
41+
* Creates a {@link DoubleValuesSource} that returns vector similarity score between provided
42+
* query vector and field for documents.
43+
*
44+
* @param vector the query vector
45+
* @param fieldName the field name of the {@link org.apache.lucene.document.KnnFloatVectorField}
46+
* @param vectorSimilarityFunction the vector similarity function to use
47+
*/
48+
public FullPrecisionFloatVectorSimilarityValuesSource(
49+
float[] vector, String fieldName, VectorSimilarityFunction vectorSimilarityFunction) {
50+
this.queryVector = vector;
51+
this.fieldName = fieldName;
52+
this.vectorSimilarityFunction = vectorSimilarityFunction;
53+
}
54+
55+
/**
56+
* Creates a {@link DoubleValuesSource} that returns vector similarity score between provided
57+
* query vector and field for documents. Uses the configured vector similarity function for the
58+
* field.
59+
*
60+
* @param vector the query vector
61+
* @param fieldName the field name of the {@link org.apache.lucene.document.KnnFloatVectorField}
62+
*/
63+
public FullPrecisionFloatVectorSimilarityValuesSource(float[] vector, String fieldName) {
64+
this(vector, fieldName, null);
65+
}
66+
67+
/** Sugar to fetch full precision similarity score values */
68+
public DoubleValues getSimilarityScores(LeafReaderContext ctx) throws IOException {
69+
return getValues(ctx, null);
70+
}
71+
72+
@Override
73+
public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException {
74+
final FloatVectorValues vectorValues = ctx.reader().getFloatVectorValues(fieldName);
75+
if (vectorValues == null) {
76+
FloatVectorValues.checkField(ctx.reader(), fieldName);
77+
return DoubleValues.EMPTY;
78+
}
79+
final FieldInfo fi = ctx.reader().getFieldInfos().fieldInfo(fieldName);
80+
if (fi.getVectorDimension() != queryVector.length) {
81+
throw new IllegalArgumentException(
82+
"Query vector dimension does not match field dimension: "
83+
+ queryVector.length
84+
+ " != "
85+
+ fi.getVectorDimension());
86+
}
87+
88+
if (vectorSimilarityFunction == null) {
89+
this.vectorSimilarityFunction = fi.getVectorSimilarityFunction();
90+
}
91+
final KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
92+
return new DoubleValues() {
93+
@Override
94+
public double doubleValue() throws IOException {
95+
return vectorSimilarityFunction.compare(
96+
queryVector, vectorValues.vectorValue(iterator.index()));
97+
}
98+
99+
@Override
100+
public boolean advanceExact(int doc) throws IOException {
101+
return doc >= iterator.docID() && (iterator.docID() == doc || iterator.advance(doc) == doc);
102+
}
103+
};
104+
}
105+
106+
@Override
107+
public boolean needsScores() {
108+
return false;
109+
}
110+
111+
@Override
112+
public DoubleValuesSource rewrite(IndexSearcher reader) throws IOException {
113+
return this;
114+
}
115+
116+
@Override
117+
public int hashCode() {
118+
return Objects.hash(fieldName, Arrays.hashCode(queryVector), vectorSimilarityFunction);
119+
}
120+
121+
@Override
122+
public boolean equals(Object obj) {
123+
if (this == obj) return true;
124+
if (obj == null || getClass() != obj.getClass()) return false;
125+
FullPrecisionFloatVectorSimilarityValuesSource other =
126+
(FullPrecisionFloatVectorSimilarityValuesSource) obj;
127+
return Objects.equals(fieldName, other.fieldName)
128+
&& Objects.equals(vectorSimilarityFunction, other.vectorSimilarityFunction)
129+
&& Arrays.equals(queryVector, other.queryVector);
130+
}
131+
132+
@Override
133+
public String toString() {
134+
return "FullPrecisionFloatVectorSimilarityValuesSource(fieldName="
135+
+ fieldName
136+
+ " vectorSimilarityFunction="
137+
+ vectorSimilarityFunction.name()
138+
+ " queryVector="
139+
+ Arrays.toString(queryVector)
140+
+ ")";
141+
}
142+
143+
@Override
144+
public boolean isCacheable(LeafReaderContext ctx) {
145+
return true;
146+
}
147+
}

0 commit comments

Comments
 (0)