Skip to content

Commit 0a90b51

Browse files
committed
add vector index create by data type
1 parent f1b6011 commit 0a90b51

File tree

6 files changed

+135
-37
lines changed

6 files changed

+135
-37
lines changed

paimon-lucene/src/main/java/org/apache/paimon/lucene/index/ByteVectorIndex.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public ByteVectorIndex(long rowId, byte[] vector) {
3333
}
3434

3535
@Override
36-
public long rowId() {
36+
public long id() {
3737
return rowId;
3838
}
3939

paimon-lucene/src/main/java/org/apache/paimon/lucene/index/FloatVectorIndex.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ public float[] vector() {
3838
}
3939

4040
@Override
41-
public long rowId() {
41+
public long id() {
4242
return rowId;
4343
}
4444

paimon-lucene/src/main/java/org/apache/paimon/lucene/index/VectorGlobalIndexWriter.java

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.apache.paimon.options.Options;
2424
import org.apache.paimon.types.ArrayType;
2525
import org.apache.paimon.types.DataType;
26+
import org.apache.paimon.types.DataTypes;
2627
import org.apache.paimon.utils.Range;
2728

2829
import org.apache.lucene.codecs.KnnVectorsFormat;
@@ -34,6 +35,7 @@
3435
import org.apache.lucene.index.VectorSimilarityFunction;
3536
import org.apache.lucene.store.Directory;
3637
import org.apache.lucene.store.IOContext;
38+
import org.apache.lucene.store.IndexInput;
3739

3840
import java.io.BufferedOutputStream;
3941
import java.io.IOException;
@@ -43,8 +45,6 @@
4345
import java.util.ArrayList;
4446
import java.util.List;
4547

46-
import static org.apache.paimon.utils.Preconditions.checkArgument;
47-
4848
/**
4949
* Vector global index writer using Apache Lucene 9.x.
5050
*
@@ -53,19 +53,22 @@
5353
*/
5454
public class VectorGlobalIndexWriter implements GlobalIndexWriter {
5555

56+
private static final DataType FLOAT_ARRAY_TYPE = new ArrayType(DataTypes.FLOAT());
57+
private static final DataType BYTE_ARRAY_TYPE = new ArrayType(DataTypes.TINYINT());
58+
5659
private final GlobalIndexFileWriter fileWriter;
5760
private final VectorIndexOptions vectorOptions;
5861
private final VectorSimilarityFunction similarityFunction;
5962
private final int sizePerIndex;
63+
private final VectorIndexFactory vectorIndexFactory;
6064

65+
private long count = 0;
6166
private final List<VectorIndex> vectorIndices;
6267
private final List<ResultEntry> results;
6368

6469
public VectorGlobalIndexWriter(
6570
GlobalIndexFileWriter fileWriter, DataType fieldType, Options options) {
66-
checkArgument(
67-
fieldType instanceof ArrayType,
68-
"Vector field type must be ARRAY, but was: " + fieldType);
71+
this.vectorIndexFactory = VectorIndexFactory.init(fieldType);
6972
this.fileWriter = fileWriter;
7073
this.vectorIndices = new ArrayList<>();
7174
this.results = new ArrayList<>();
@@ -76,15 +79,8 @@ public VectorGlobalIndexWriter(
7679

7780
@Override
7881
public void write(Object key) {
79-
VectorIndex index;
80-
if (key instanceof FloatVectorIndex) {
81-
index = (FloatVectorIndex) key;
82-
} else if (key instanceof ByteVectorIndex) {
83-
index = (ByteVectorIndex) key;
84-
} else {
85-
throw new IllegalArgumentException(
86-
"Unsupported index type: " + key.getClass().getName());
87-
}
82+
count++;
83+
VectorIndex index = vectorIndexFactory.create(count, key);
8884
index.checkDimension(vectorOptions.dimension());
8985
vectorIndices.add(index);
9086
if (vectorIndices.size() >= sizePerIndex) {
@@ -119,8 +115,8 @@ private void flush() throws IOException {
119115
this.vectorOptions.writeBufferSize(),
120116
out);
121117
}
122-
long minRowIdInBatch = vectorIndices.get(0).rowId();
123-
long maxRowIdInBatch = vectorIndices.get(vectorIndices.size() - 1).rowId();
118+
long minRowIdInBatch = vectorIndices.get(0).id();
119+
long maxRowIdInBatch = vectorIndices.get(vectorIndices.size() - 1).id();
124120
results.add(ResultEntry.of(fileName, null, new Range(minRowIdInBatch, maxRowIdInBatch)));
125121
vectorIndices.clear();
126122
}
@@ -176,8 +172,7 @@ private void serializeDirectory(Directory directory, OutputStream out) throws IO
176172
long fileLength = directory.fileLength(fileName);
177173
out.write(ByteBuffer.allocate(8).putLong(fileLength).array());
178174

179-
try (org.apache.lucene.store.IndexInput input =
180-
directory.openInput(fileName, IOContext.DEFAULT)) {
175+
try (IndexInput input = directory.openInput(fileName, IOContext.DEFAULT)) {
181176
byte[] buffer = new byte[32 * 1024];
182177
long remaining = fileLength;
183178

paimon-lucene/src/main/java/org/apache/paimon/lucene/index/VectorIndex.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@
2626
public abstract class VectorIndex<T> {
2727

2828
public static final String VECTOR_FIELD = "vector";
29-
public static final String ROW_ID_FIELD = "rowId";
29+
public static final String ROW_ID_FIELD = "id";
3030

31-
public abstract long rowId();
31+
public abstract long id();
3232

3333
public abstract long dimension();
3434

@@ -37,7 +37,7 @@ public abstract class VectorIndex<T> {
3737
public abstract IndexableField indexableField(VectorSimilarityFunction similarityFunction);
3838

3939
public StoredField rowIdStoredField() {
40-
return new StoredField(ROW_ID_FIELD, rowId());
40+
return new StoredField(ROW_ID_FIELD, id());
4141
}
4242

4343
public void checkDimension(int dimension) {
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.paimon.lucene.index;
20+
21+
import org.apache.paimon.types.ArrayType;
22+
import org.apache.paimon.types.DataType;
23+
import org.apache.paimon.types.FloatType;
24+
import org.apache.paimon.types.TinyIntType;
25+
26+
/** Factory for creating vector index instances based on data type. */
27+
public abstract class VectorIndexFactory {
28+
29+
public static VectorIndexFactory init(DataType dataType) {
30+
if (dataType instanceof ArrayType
31+
&& ((ArrayType) dataType).getElementType() instanceof FloatType) {
32+
return new FloatVectorIndexFactory();
33+
} else if (dataType instanceof ArrayType
34+
&& ((ArrayType) dataType).getElementType() instanceof TinyIntType) {
35+
return new ByteVectorIndexFactory();
36+
} else {
37+
throw new IllegalArgumentException("Unsupported data type: " + dataType);
38+
}
39+
}
40+
41+
public abstract VectorIndex create(long rowId, Object vector);
42+
43+
/** Factory for creating FloatVectorIndex instances. */
44+
public static class FloatVectorIndexFactory extends VectorIndexFactory {
45+
@Override
46+
public VectorIndex create(long rowId, Object vector) {
47+
return new FloatVectorIndex(rowId, (float[]) vector);
48+
}
49+
}
50+
51+
/** Factory for creating FloatVectorIndex instances. */
52+
public static class ByteVectorIndexFactory extends VectorIndexFactory {
53+
@Override
54+
public VectorIndex create(long rowId, Object vector) {
55+
return new ByteVectorIndex(rowId, (byte[]) vector);
56+
}
57+
}
58+
}

paimon-lucene/src/test/java/org/apache/paimon/lucene/index/VectorGlobalIndexTest.java

Lines changed: 59 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.apache.paimon.types.ArrayType;
3131
import org.apache.paimon.types.DataType;
3232
import org.apache.paimon.types.FloatType;
33+
import org.apache.paimon.types.TinyIntType;
3334

3435
import org.junit.jupiter.api.AfterEach;
3536
import org.junit.jupiter.api.BeforeEach;
@@ -39,6 +40,7 @@
3940
import java.io.IOException;
4041
import java.io.OutputStream;
4142
import java.util.ArrayList;
43+
import java.util.Arrays;
4244
import java.util.List;
4345
import java.util.Random;
4446
import java.util.UUID;
@@ -104,9 +106,7 @@ public void testDifferentSimilarityFunctions() throws IOException {
104106
new VectorGlobalIndexWriter(fileWriter, vectorType, options);
105107

106108
List<float[]> testVectors = generateRandomVectors(numVectors, dimension);
107-
for (int i = 0; i < numVectors; i++) {
108-
writer.write(new FloatVectorIndex(i, testVectors.get(i)));
109-
}
109+
testVectors.forEach(writer::write);
110110

111111
List<GlobalIndexWriter.ResultEntry> results = writer.finish();
112112
assertThat(results).hasSize(1);
@@ -142,9 +142,7 @@ public void testDifferentDimensions() throws IOException {
142142

143143
int numVectors = 10;
144144
List<float[]> testVectors = generateRandomVectors(numVectors, dimension);
145-
for (int i = 0; i < numVectors; i++) {
146-
writer.write(new FloatVectorIndex(i, testVectors.get(i)));
147-
}
145+
testVectors.forEach(writer::write);
148146

149147
List<GlobalIndexWriter.ResultEntry> results = writer.finish();
150148
assertThat(results).hasSize(1);
@@ -177,7 +175,7 @@ public void testDimensionMismatch() throws IOException {
177175

178176
// Try to write vector with wrong dimension
179177
float[] wrongDimVector = new float[32]; // Wrong dimension
180-
assertThatThrownBy(() -> writer.write(new FloatVectorIndex(0, wrongDimVector)))
178+
assertThatThrownBy(() -> writer.write(wrongDimVector))
181179
.isInstanceOf(IllegalArgumentException.class)
182180
.hasMessageContaining("dimension mismatch");
183181
}
@@ -186,7 +184,8 @@ public void testDimensionMismatch() throws IOException {
186184
public void testFloatVectorIndexEndToEnd() throws IOException {
187185
int dimension = 2;
188186
Options options = createDefaultOptions(dimension);
189-
options.setInteger("vector.size-per-index", 3);
187+
int sizePerIndex = 3;
188+
options.setInteger("vector.size-per-index", sizePerIndex);
190189

191190
float[][] vectors =
192191
new float[][] {
@@ -197,16 +196,15 @@ public void testFloatVectorIndexEndToEnd() throws IOException {
197196
GlobalIndexFileWriter fileWriter = createFileWriter(indexPath);
198197
VectorGlobalIndexWriter writer =
199198
new VectorGlobalIndexWriter(fileWriter, vectorType, options);
200-
for (int i = 0; i < vectors.length; i++) {
201-
writer.write(new FloatVectorIndex(i, vectors[i]));
202-
}
199+
Arrays.stream(vectors).forEach(writer::write);
203200

204201
List<GlobalIndexWriter.ResultEntry> results = writer.finish();
205202
assertThat(results).hasSize(2);
206203

207204
GlobalIndexFileReader fileReader = createFileReader(indexPath);
208205
List<GlobalIndexIOMeta> metas = new ArrayList<>();
209-
for (GlobalIndexWriter.ResultEntry result : results) {
206+
for (int i = 0; i < results.size(); i++) {
207+
GlobalIndexWriter.ResultEntry result = results.get(i);
210208
metas.add(
211209
new GlobalIndexIOMeta(
212210
result.fileName(),
@@ -218,13 +216,60 @@ public void testFloatVectorIndexEndToEnd() throws IOException {
218216
try (VectorGlobalIndexReader reader = new VectorGlobalIndexReader(fileReader, metas)) {
219217
GlobalIndexResult result = reader.search(vectors[0], 1);
220218
assertThat(result.results().getLongCardinality()).isEqualTo(1);
221-
assertThat(containsRowId(result, 0)).isTrue();
219+
assertThat(containsRowId(result, 1)).isTrue();
222220

223221
float[] queryVector = new float[] {0.85f, 0.15f};
224222
result = reader.search(queryVector, 2);
225223
assertThat(result.results().getLongCardinality()).isEqualTo(2);
224+
assertThat(containsRowId(result, 2)).isTrue();
225+
assertThat(containsRowId(result, 4)).isTrue();
226+
}
227+
}
228+
229+
@Test
230+
public void testByteVectorIndexEndToEnd() throws IOException {
231+
int dimension = 2;
232+
Options options = createDefaultOptions(dimension);
233+
int sizePerIndex = 3;
234+
options.setInteger("vector.size-per-index", sizePerIndex);
235+
236+
byte[][] vectors =
237+
new byte[][] {
238+
new byte[] {100, 0}, new byte[] {95, 10}, new byte[] {10, 95},
239+
new byte[] {98, 5}, new byte[] {0, 100}, new byte[] {5, 98}
240+
};
241+
242+
DataType byteVectorType = new ArrayType(new TinyIntType());
243+
GlobalIndexFileWriter fileWriter = createFileWriter(indexPath);
244+
VectorGlobalIndexWriter writer =
245+
new VectorGlobalIndexWriter(fileWriter, byteVectorType, options);
246+
Arrays.stream(vectors).forEach(writer::write);
247+
248+
List<GlobalIndexWriter.ResultEntry> results = writer.finish();
249+
assertThat(results).hasSize(2);
250+
251+
GlobalIndexFileReader fileReader = createFileReader(indexPath);
252+
List<GlobalIndexIOMeta> metas = new ArrayList<>();
253+
for (int i = 0; i < results.size(); i++) {
254+
GlobalIndexWriter.ResultEntry result = results.get(i);
255+
metas.add(
256+
new GlobalIndexIOMeta(
257+
result.fileName(),
258+
fileIO.getFileSize(new Path(indexPath, result.fileName())),
259+
result.rowRange(),
260+
result.meta()));
261+
}
262+
263+
try (VectorGlobalIndexReader reader = new VectorGlobalIndexReader(fileReader, metas)) {
264+
GlobalIndexResult result = reader.search(vectors[0], 1);
265+
assertThat(result.results().getLongCardinality()).isEqualTo(1);
226266
assertThat(containsRowId(result, 1)).isTrue();
227-
assertThat(containsRowId(result, 3)).isTrue();
267+
268+
byte[] queryVector = new byte[] {85, 15};
269+
result = reader.search(queryVector, 2);
270+
assertThat(result.results().getLongCardinality()).isEqualTo(2);
271+
assertThat(containsRowId(result, 2)).isTrue();
272+
assertThat(containsRowId(result, 4)).isTrue();
228273
}
229274
}
230275

0 commit comments

Comments
 (0)