Skip to content

Commit bd35649

Browse files
authored
Add ES93HnswVectorsFormat (#136474)
A generic HNSW vectors format that supports directIO and bfloat16
1 parent a8449b2 commit bd35649

File tree

7 files changed

+214
-3
lines changed

7 files changed

+214
-3
lines changed

server/src/main/java/module-info.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,7 @@
466466
org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat,
467467
org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat,
468468
org.elasticsearch.index.codec.vectors.es93.ES93BinaryQuantizedVectorsFormat,
469+
org.elasticsearch.index.codec.vectors.es93.ES93HnswVectorsFormat,
469470
org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat;
470471

471472
provides org.apache.lucene.codecs.Codec

server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BFloat16FlatVectorsWriter.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
2424
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
2525
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
26-
import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues;
2726
import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
2827
import org.apache.lucene.index.DocsWithFieldSet;
2928
import org.apache.lucene.index.FieldInfo;
@@ -250,7 +249,7 @@ public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(FieldInfo fieldI
250249
final IndexInput finalVectorDataInput = vectorDataInput;
251250
final RandomVectorScorerSupplier randomVectorScorerSupplier = vectorsScorer.getRandomVectorScorerSupplier(
252251
fieldInfo.getVectorSimilarityFunction(),
253-
new OffHeapFloatVectorValues.DenseOffHeapVectorValues(
252+
new OffHeapBFloat16VectorValues.DenseOffHeapVectorValues(
254253
fieldInfo.getVectorDimension(),
255254
docsWithField.cardinality(),
256255
finalVectorDataInput,
server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormat.java

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.index.codec.vectors.es93;
11+
12+
import org.apache.lucene.codecs.KnnVectorsReader;
13+
import org.apache.lucene.codecs.KnnVectorsWriter;
14+
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
15+
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
16+
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter;
17+
import org.apache.lucene.index.SegmentReadState;
18+
import org.apache.lucene.index.SegmentWriteState;
19+
import org.elasticsearch.index.codec.vectors.AbstractHnswVectorsFormat;
20+
21+
import java.io.IOException;
22+
import java.util.concurrent.ExecutorService;
23+
24+
public class ES93HnswVectorsFormat extends AbstractHnswVectorsFormat {
25+
26+
static final String NAME = "ES93HnswVectorsFormat";
27+
28+
private final FlatVectorsFormat flatVectorsFormat;
29+
30+
public ES93HnswVectorsFormat() {
31+
super(NAME);
32+
flatVectorsFormat = new ES93GenericFlatVectorsFormat();
33+
}
34+
35+
public ES93HnswVectorsFormat(boolean bfloat16, boolean useDirectIO) {
36+
super(NAME);
37+
flatVectorsFormat = new ES93GenericFlatVectorsFormat(bfloat16, useDirectIO);
38+
}
39+
40+
public ES93HnswVectorsFormat(int maxConn, int beamWidth, boolean bfloat16, boolean useDirectIO) {
41+
super(NAME, maxConn, beamWidth);
42+
flatVectorsFormat = new ES93GenericFlatVectorsFormat(bfloat16, useDirectIO);
43+
}
44+
45+
public ES93HnswVectorsFormat(
46+
int maxConn,
47+
int beamWidth,
48+
boolean bfloat16,
49+
boolean useDirectIO,
50+
int numMergeWorkers,
51+
ExecutorService mergeExec
52+
) {
53+
super(NAME, maxConn, beamWidth, numMergeWorkers, mergeExec);
54+
flatVectorsFormat = new ES93GenericFlatVectorsFormat(bfloat16, useDirectIO);
55+
}
56+
57+
@Override
58+
protected FlatVectorsFormat flatVectorsFormat() {
59+
return flatVectorsFormat;
60+
}
61+
62+
@Override
63+
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
64+
return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), numMergeWorkers, mergeExec);
65+
}
66+
67+
@Override
68+
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
69+
return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state));
70+
}
71+
}

server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@ org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsForma
1010
org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat
1111
org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat
1212
org.elasticsearch.index.codec.vectors.es93.ES93BinaryQuantizedVectorsFormat
13+
org.elasticsearch.index.codec.vectors.es93.ES93HnswVectorsFormat
1314
org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat

server/src/test/java/org/elasticsearch/index/codec/vectors/BaseHnswBFloat16VectorsFormatTestCase.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ public void testSingleVectorCase() throws Exception {
112112
assertEquals(1, td.totalHits.value());
113113
assertThat(td.scoreDocs[0].score, greaterThanOrEqualTo(0f));
114114
// When it's the only vector in a segment, the score should be very close to the true score
115-
assertEquals(trueScore, td.scoreDocs[0].score, 0.01f);
115+
assertEquals(trueScore, td.scoreDocs[0].score, trueScore / 100);
116116
}
117117
}
118118
}
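server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBFloat16VectorsFormatTests.java

Lines changed: 72 additions & 0 deletions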
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.index.codec.vectors.es93;
11+
12+
import org.apache.lucene.codecs.KnnVectorsFormat;
13+
import org.apache.lucene.store.Directory;
14+
import org.elasticsearch.index.codec.vectors.BFloat16;
15+
import org.elasticsearch.index.codec.vectors.BaseHnswBFloat16VectorsFormatTestCase;
16+
17+
import java.io.IOException;
18+
import java.util.Locale;
19+
import java.util.concurrent.ExecutorService;
20+
21+
import static java.lang.String.format;
22+
import static org.hamcrest.Matchers.aMapWithSize;
23+
import static org.hamcrest.Matchers.allOf;
24+
import static org.hamcrest.Matchers.hasEntry;
25+
import static org.hamcrest.Matchers.hasToString;
26+
import static org.hamcrest.Matchers.is;
27+
import static org.hamcrest.Matchers.oneOf;
28+
29+
public class ES93HnswBFloat16VectorsFormatTests extends BaseHnswBFloat16VectorsFormatTestCase {
30+
31+
@Override
32+
protected KnnVectorsFormat createFormat() {
33+
return new ES93HnswVectorsFormat(true, random().nextBoolean());
34+
}
35+
36+
@Override
37+
protected KnnVectorsFormat createFormat(int maxConn, int beamWidth) {
38+
return new ES93HnswVectorsFormat(maxConn, beamWidth, true, random().nextBoolean());
39+
}
40+
41+
@Override
42+
protected KnnVectorsFormat createFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService service) {
43+
return new ES93HnswVectorsFormat(maxConn, beamWidth, true, random().nextBoolean(), numMergeWorkers, service);
44+
}
45+
46+
public void testToString() {
47+
String expected = "ES93HnswVectorsFormat(name=ES93HnswVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=%s)";
48+
expected = format(Locale.ROOT, expected, "ES93GenericFlatVectorsFormat(name=ES93GenericFlatVectorsFormat, format=%s)");
49+
expected = format(
50+
Locale.ROOT,
51+
expected,
52+
"ES93BFloat16FlatVectorsFormat(name=ES93BFloat16FlatVectorsFormat, flatVectorScorer=%s())"
53+
);
54+
String defaultScorer = format(Locale.ROOT, expected, "DefaultFlatVectorScorer");
55+
String memSegScorer = format(Locale.ROOT, expected, "Lucene99MemorySegmentFlatVectorsScorer");
56+
57+
KnnVectorsFormat format = createFormat(10, 20, 1, null);
58+
assertThat(format, hasToString(is(oneOf(defaultScorer, memSegScorer))));
59+
}
60+
61+
public void testSimpleOffHeapSize() throws IOException {
62+
float[] vector = randomVector(random().nextInt(12, 500));
63+
try (Directory dir = newDirectory()) {
64+
testSimpleOffHeapSize(
65+
dir,
66+
newIndexWriterConfig(),
67+
vector,
68+
allOf(aMapWithSize(2), hasEntry("vec", (long) vector.length * BFloat16.BYTES), hasEntry("vex", 1L))
69+
);
70+
}
71+
}
72+
}
server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswVectorsFormatTests.java

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.index.codec.vectors.es93;
11+
12+
import org.apache.lucene.codecs.KnnVectorsFormat;
13+
import org.apache.lucene.store.Directory;
14+
import org.elasticsearch.index.codec.vectors.BaseHnswVectorsFormatTestCase;
15+
16+
import java.io.IOException;
17+
import java.util.Locale;
18+
import java.util.concurrent.ExecutorService;
19+
20+
import static java.lang.String.format;
21+
import static org.hamcrest.Matchers.aMapWithSize;
22+
import static org.hamcrest.Matchers.allOf;
23+
import static org.hamcrest.Matchers.hasEntry;
24+
import static org.hamcrest.Matchers.hasToString;
25+
import static org.hamcrest.Matchers.is;
26+
import static org.hamcrest.Matchers.oneOf;
27+
28+
public class ES93HnswVectorsFormatTests extends BaseHnswVectorsFormatTestCase {
29+
30+
@Override
31+
protected KnnVectorsFormat createFormat() {
32+
return new ES93HnswVectorsFormat(false, random().nextBoolean());
33+
}
34+
35+
@Override
36+
protected KnnVectorsFormat createFormat(int maxConn, int beamWidth) {
37+
return new ES93HnswVectorsFormat(maxConn, beamWidth, false, random().nextBoolean());
38+
}
39+
40+
@Override
41+
protected KnnVectorsFormat createFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService service) {
42+
return new ES93HnswVectorsFormat(maxConn, beamWidth, false, random().nextBoolean(), numMergeWorkers, service);
43+
}
44+
45+
public void testToString() {
46+
String expected = "ES93HnswVectorsFormat(name=ES93HnswVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=%s)";
47+
expected = format(Locale.ROOT, expected, "ES93GenericFlatVectorsFormat(name=ES93GenericFlatVectorsFormat, format=%s)");
48+
expected = format(Locale.ROOT, expected, "Lucene99FlatVectorsFormat(name=Lucene99FlatVectorsFormat, flatVectorScorer=%s())");
49+
String defaultScorer = format(Locale.ROOT, expected, "DefaultFlatVectorScorer");
50+
String memSegScorer = format(Locale.ROOT, expected, "Lucene99MemorySegmentFlatVectorsScorer");
51+
52+
KnnVectorsFormat format = createFormat(10, 20, 1, null);
53+
assertThat(format, hasToString(is(oneOf(defaultScorer, memSegScorer))));
54+
}
55+
56+
public void testSimpleOffHeapSize() throws IOException {
57+
float[] vector = randomVector(random().nextInt(12, 500));
58+
try (Directory dir = newDirectory()) {
59+
testSimpleOffHeapSize(
60+
dir,
61+
newIndexWriterConfig(),
62+
vector,
63+
allOf(aMapWithSize(2), hasEntry("vec", (long) vector.length * Float.BYTES), hasEntry("vex", 1L))
64+
);
65+
}
66+
}
67+
}

0 commit comments

Comments
 (0)