3030import org .apache .paimon .types .ArrayType ;
3131import org .apache .paimon .types .DataType ;
3232import org .apache .paimon .types .FloatType ;
33+ import org .apache .paimon .types .TinyIntType ;
3334
3435import org .junit .jupiter .api .AfterEach ;
3536import org .junit .jupiter .api .BeforeEach ;
3940import java .io .IOException ;
4041import java .io .OutputStream ;
4142import java .util .ArrayList ;
43+ import java .util .Arrays ;
4244import java .util .List ;
4345import java .util .Random ;
4446import java .util .UUID ;
@@ -104,9 +106,7 @@ public void testDifferentSimilarityFunctions() throws IOException {
104106 new VectorGlobalIndexWriter (fileWriter , vectorType , options );
105107
106108 List <float []> testVectors = generateRandomVectors (numVectors , dimension );
107- for (int i = 0 ; i < numVectors ; i ++) {
108- writer .write (new FloatVectorIndex (i , testVectors .get (i )));
109- }
109+ testVectors .forEach (writer ::write );
110110
111111 List <GlobalIndexWriter .ResultEntry > results = writer .finish ();
112112 assertThat (results ).hasSize (1 );
@@ -142,9 +142,7 @@ public void testDifferentDimensions() throws IOException {
142142
143143 int numVectors = 10 ;
144144 List <float []> testVectors = generateRandomVectors (numVectors , dimension );
145- for (int i = 0 ; i < numVectors ; i ++) {
146- writer .write (new FloatVectorIndex (i , testVectors .get (i )));
147- }
145+ testVectors .forEach (writer ::write );
148146
149147 List <GlobalIndexWriter .ResultEntry > results = writer .finish ();
150148 assertThat (results ).hasSize (1 );
@@ -177,7 +175,7 @@ public void testDimensionMismatch() throws IOException {
177175
178176 // Try to write vector with wrong dimension
179177 float [] wrongDimVector = new float [32 ]; // Wrong dimension
180- assertThatThrownBy (() -> writer .write (new FloatVectorIndex ( 0 , wrongDimVector ) ))
178+ assertThatThrownBy (() -> writer .write (wrongDimVector ))
181179 .isInstanceOf (IllegalArgumentException .class )
182180 .hasMessageContaining ("dimension mismatch" );
183181 }
@@ -186,7 +184,8 @@ public void testDimensionMismatch() throws IOException {
186184 public void testFloatVectorIndexEndToEnd () throws IOException {
187185 int dimension = 2 ;
188186 Options options = createDefaultOptions (dimension );
189- options .setInteger ("vector.size-per-index" , 3 );
187+ int sizePerIndex = 3 ;
188+ options .setInteger ("vector.size-per-index" , sizePerIndex );
190189
191190 float [][] vectors =
192191 new float [][] {
@@ -197,16 +196,15 @@ public void testFloatVectorIndexEndToEnd() throws IOException {
197196 GlobalIndexFileWriter fileWriter = createFileWriter (indexPath );
198197 VectorGlobalIndexWriter writer =
199198 new VectorGlobalIndexWriter (fileWriter , vectorType , options );
200- for (int i = 0 ; i < vectors .length ; i ++) {
201- writer .write (new FloatVectorIndex (i , vectors [i ]));
202- }
199+ Arrays .stream (vectors ).forEach (writer ::write );
203200
204201 List <GlobalIndexWriter .ResultEntry > results = writer .finish ();
205202 assertThat (results ).hasSize (2 );
206203
207204 GlobalIndexFileReader fileReader = createFileReader (indexPath );
208205 List <GlobalIndexIOMeta > metas = new ArrayList <>();
209- for (GlobalIndexWriter .ResultEntry result : results ) {
206+ for (int i = 0 ; i < results .size (); i ++) {
207+ GlobalIndexWriter .ResultEntry result = results .get (i );
210208 metas .add (
211209 new GlobalIndexIOMeta (
212210 result .fileName (),
@@ -218,13 +216,60 @@ public void testFloatVectorIndexEndToEnd() throws IOException {
218216 try (VectorGlobalIndexReader reader = new VectorGlobalIndexReader (fileReader , metas )) {
219217 GlobalIndexResult result = reader .search (vectors [0 ], 1 );
220218 assertThat (result .results ().getLongCardinality ()).isEqualTo (1 );
221- assertThat (containsRowId (result , 0 )).isTrue ();
219+ assertThat (containsRowId (result , 1 )).isTrue ();
222220
223221 float [] queryVector = new float [] {0.85f , 0.15f };
224222 result = reader .search (queryVector , 2 );
225223 assertThat (result .results ().getLongCardinality ()).isEqualTo (2 );
224+ assertThat (containsRowId (result , 2 )).isTrue ();
225+ assertThat (containsRowId (result , 4 )).isTrue ();
226+ }
227+ }
228+
229+ @ Test
230+ public void testByteVectorIndexEndToEnd () throws IOException {
231+ int dimension = 2 ;
232+ Options options = createDefaultOptions (dimension );
233+ int sizePerIndex = 3 ;
234+ options .setInteger ("vector.size-per-index" , sizePerIndex );
235+
236+ byte [][] vectors =
237+ new byte [][] {
238+ new byte [] {100 , 0 }, new byte [] {95 , 10 }, new byte [] {10 , 95 },
239+ new byte [] {98 , 5 }, new byte [] {0 , 100 }, new byte [] {5 , 98 }
240+ };
241+
242+ DataType byteVectorType = new ArrayType (new TinyIntType ());
243+ GlobalIndexFileWriter fileWriter = createFileWriter (indexPath );
244+ VectorGlobalIndexWriter writer =
245+ new VectorGlobalIndexWriter (fileWriter , byteVectorType , options );
246+ Arrays .stream (vectors ).forEach (writer ::write );
247+
248+ List <GlobalIndexWriter .ResultEntry > results = writer .finish ();
249+ assertThat (results ).hasSize (2 );
250+
251+ GlobalIndexFileReader fileReader = createFileReader (indexPath );
252+ List <GlobalIndexIOMeta > metas = new ArrayList <>();
253+ for (int i = 0 ; i < results .size (); i ++) {
254+ GlobalIndexWriter .ResultEntry result = results .get (i );
255+ metas .add (
256+ new GlobalIndexIOMeta (
257+ result .fileName (),
258+ fileIO .getFileSize (new Path (indexPath , result .fileName ())),
259+ result .rowRange (),
260+ result .meta ()));
261+ }
262+
263+ try (VectorGlobalIndexReader reader = new VectorGlobalIndexReader (fileReader , metas )) {
264+ GlobalIndexResult result = reader .search (vectors [0 ], 1 );
265+ assertThat (result .results ().getLongCardinality ()).isEqualTo (1 );
226266 assertThat (containsRowId (result , 1 )).isTrue ();
227- assertThat (containsRowId (result , 3 )).isTrue ();
267+
268+ byte [] queryVector = new byte [] {85 , 15 };
269+ result = reader .search (queryVector , 2 );
270+ assertThat (result .results ().getLongCardinality ()).isEqualTo (2 );
271+ assertThat (containsRowId (result , 2 )).isTrue ();
272+ assertThat (containsRowId (result , 4 )).isTrue ();
228273 }
229274 }
230275
0 commit comments