29
29
test_groundtruth_l2 , \
30
30
test_groundtruth_mip , \
31
31
test_number_of_vectors , \
32
- test_dimensions
32
+ test_dimensions , \
33
+ test_get_distance
33
34
34
35
class FlatTester (unittest .TestCase ):
35
36
"""
@@ -54,7 +55,7 @@ def _loaders(self, file: svs.VectorDataLoader):
54
55
}),
55
56
]
56
57
57
- def _do_test (self , flat , queries , groundtruth , expected_recall = 1.0 ):
58
+ def _do_test (self , flat , queries , groundtruth , distance , data = svs . read_vecs ( test_data_vecs ), expected_recall = 1.0 , test_distance = True ):
58
59
"""
59
60
Perform a series of tests on a Flat index to test its conformance to expectations.
60
61
Parameters:
@@ -67,6 +68,9 @@ def _do_test(self, flat, queries, groundtruth, expected_recall = 1.0):
67
68
- Results of `search` are within acceptable margins of the groundtruth.
68
69
- The number of threads can be changed with an observable side-effect.
69
70
"""
71
+ # Test get distance
72
+ test_get_distance (flat , distance , data , test_distance )
73
+
70
74
# Data interface
71
75
self .assertEqual (flat .size , test_number_of_vectors )
72
76
self .assertEqual (flat .dimensions , test_dimensions )
@@ -117,7 +121,7 @@ def _do_test_from_file(self, distance: svs.DistanceType, queries, groundtruth):
117
121
svs .VectorDataLoader (
118
122
test_data_svs , svs .DataType .float32 , dims = test_data_dims
119
123
)
120
- );
124
+ )
121
125
for loader , recall in loaders :
122
126
index = svs .Flat (
123
127
loader ,
@@ -126,7 +130,7 @@ def _do_test_from_file(self, distance: svs.DistanceType, queries, groundtruth):
126
130
)
127
131
128
132
self .assertEqual (index .num_threads , num_threads )
129
- self ._do_test (index , queries , groundtruth , expected_recall = recall [distance ])
133
+ self ._do_test (index , queries , groundtruth , distance , expected_recall = recall [distance ])
130
134
131
135
def test_from_file (self ):
132
136
"""
@@ -154,21 +158,22 @@ def test_from_array(self):
154
158
# Test `float32`
155
159
print ("Flat, From Array, Float32" )
156
160
flat = svs .Flat (data_f32 , svs .DistanceType .L2 )
157
- self ._do_test (flat , queries_f32 , groundtruth )
161
+ self ._do_test (flat , queries_f32 , groundtruth , svs . DistanceType . L2 , data_f32 )
158
162
159
163
# Test `float16`
160
164
print ("Flat, From Array, Float16" )
161
165
data_f16 = data_f32 .astype ('float16' )
162
166
queries_f16 = queries_f32 .astype ('float16' )
163
167
flat = svs .Flat (data_f16 , svs .DistanceType .L2 )
164
- self ._do_test (flat , queries_f16 , groundtruth )
168
+ # Do not test get distance for fp16 data as py_contiguous_array_t does not support it
169
+ self ._do_test (flat , queries_f16 , groundtruth , svs .DistanceType .L2 , data_f16 , test_distance = False )
165
170
166
171
# Test `int8`
167
172
print ("Flat, From Array, Int8" )
168
173
data_i8 = data_f32 .astype ('int8' )
169
174
queries_i8 = queries_f32 .astype ('int8' )
170
175
flat = svs .Flat (data_i8 , svs .DistanceType .L2 )
171
- self ._do_test (flat , queries_i8 , groundtruth )
176
+ self ._do_test (flat , queries_i8 , groundtruth , svs . DistanceType . L2 , data = data_i8 )
172
177
173
178
# Test 'uint8'
174
179
# The dataset is stored as values that can be encoded as `int8`.
@@ -178,4 +183,4 @@ def test_from_array(self):
178
183
data_u8 = (data_f32 + 128 ).astype ('uint8' )
179
184
queries_u8 = (queries_f32 + 128 ).astype ('uint8' )
180
185
flat = svs .Flat (data_u8 , svs .DistanceType .L2 )
181
- self ._do_test (flat , queries_u8 , groundtruth )
186
+ self ._do_test (flat , queries_u8 , groundtruth , svs . DistanceType . L2 , data = data_u8 )
0 commit comments