Skip to content

Commit 2c642d6

Browse files
authored
feature: get_distance function and test in dynamic_index.h (#102)
This PR solves issue #99. It adds a new get_distance function in dynamic_index.h to get the distance between a vector in the index and a normal vector.
1 parent e9f6dd8 commit 2c642d6

File tree

16 files changed

+455
-35
lines changed

16 files changed

+455
-35
lines changed

bindings/python/include/svs/python/vamana.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,29 @@ Calling this method should not affect recall.)"
213213
)"
214214
);
215215

216+
manager.def(
217+
"get_distance",
218+
[](Manager& self, size_t external_id, py_contiguous_array_t<float> query_array) {
219+
// Get raw pointer + size from the Python array
220+
const float* data_ptr = query_array.data();
221+
size_t n = query_array.size();
222+
std::vector<float> vec(data_ptr, data_ptr + n);
223+
return self.get_distance(external_id, vec);
224+
},
225+
pybind11::arg("external_id"),
226+
pybind11::arg("query_vector"),
227+
R"(
228+
Compute the distance between the stored vector at `external_id` and the provided `query_vector`.
229+
230+
Args:
231+
external_id: the external ID of the vector in the index
232+
query_vector: a 1-D contiguous array whose length must match the index dimensionality
233+
234+
Returns:
235+
float: the computed distance
236+
)"
237+
);
238+
216239
///// Experimental Interfaces
217240
add_experimental_calibration<svs::Float16>(manager);
218241
add_experimental_calibration<float>(manager);

bindings/python/src/flat.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,30 @@ queries-per-second).
274274
275275
See also: `svs.FlatSearchParameters`.)"
276276
);
277+
278+
// Register a `get_distance` overload on the Python `Flat` class for each standard
// specialization; the overload accepts a contiguous array of element type `T`.
for_standard_specializations([&flat]<typename Q, typename T, size_t N>() {
    flat.def(
        "get_distance",
        [](svs::Flat& self, size_t external_id, py_contiguous_array_t<T> query_array) {
            // Copy the array contents into an owning vector before forwarding.
            const size_t num_elements = query_array.size();
            const T* const first = query_array.data();
            std::vector<T> query(first, first + num_elements);
            return self.get_distance(external_id, query);
        },
        py::arg("external_id"),
        py::arg("query_vector"),
        R"(
Compute the distance between the stored vector at `external_id` and the provided `query_vector`.

Args:
external_id: the external ID of the vector in the index
query_vector: a 1-D contiguous array whose length must match the index dimensionality

Returns:
float: the computed distance
)"
    );
});
277301
}
278302

279303
} // namespace svs::python::flat

bindings/python/tests/common.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,3 +123,53 @@ def test_threading(f, *args, validate = None, iters = 4, print_times = False):
123123
# For short lived processes, we generally see closer to a 3x speedup than a 4x
124124
# speedup when using 4 threads.
125125
testcase.assertTrue(1.3 * new_time < base_time)
126+
127+
def test_get_distance(index, distance, data = svs.read_vecs(test_data_vecs), test_distance = True):
128+
"""
129+
Test the get_distance method of an index by comparing its results with direct distance computation.
130+
131+
Arguments:
132+
index: The SVS index object with get_distance method
133+
distance: The distance type
134+
data: The dataset used to build the index
135+
test_distance: Whether to perform the distance test
136+
"""
137+
# Skip get_distance_test if flag is set
138+
if not test_distance:
139+
return
140+
141+
tolerance=1e-2
142+
query_id = 10
143+
index_id = 100
144+
dt = data.dtype
145+
query_vector_raw = np.array(data[query_id], dtype=dt)
146+
indexed_vector_raw = np.array(data[index_id], dtype=dt)
147+
index_distance = index.get_distance(index_id, query_vector_raw)
148+
# Up cast to avoid overflow
149+
query_vector = query_vector_raw.astype(np.float32)
150+
indexed_vector = indexed_vector_raw.astype(np.float32)
151+
152+
# Compute distance based on distance type
153+
if distance == svs.DistanceType.L2:
154+
expected_distance = np.sum((query_vector - indexed_vector) ** 2)
155+
elif distance == svs.DistanceType.MIP:
156+
expected_distance = np.dot(query_vector, indexed_vector)
157+
elif distance == svs.DistanceType.Cosine:
158+
qn = np.linalg.norm(query_vector)
159+
vn = np.linalg.norm(indexed_vector)
160+
if qn == 0 or vn == 0:
161+
expected_distance = 0.0
162+
else:
163+
expected_distance = (np.dot(query_vector, indexed_vector) / (qn * vn))
164+
else:
165+
raise ValueError(f"Unsupported DistanceType: {distance}")
166+
167+
relative_diff = abs((index_distance - expected_distance) / expected_distance)
168+
assert relative_diff < tolerance
169+
170+
# Test out of bounds ID
171+
try:
172+
index.get_distance(index_id + 99999, query_vector_raw)
173+
assert False, "Should have exception"
174+
except Exception as e:
175+
pass

bindings/python/tests/test_flat.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@
2929
test_groundtruth_l2, \
3030
test_groundtruth_mip, \
3131
test_number_of_vectors, \
32-
test_dimensions
32+
test_dimensions, \
33+
test_get_distance
3334

3435
class FlatTester(unittest.TestCase):
3536
"""
@@ -54,7 +55,7 @@ def _loaders(self, file: svs.VectorDataLoader):
5455
}),
5556
]
5657

57-
def _do_test(self, flat, queries, groundtruth, expected_recall = 1.0):
58+
def _do_test(self, flat, queries, groundtruth, distance, data = svs.read_vecs(test_data_vecs), expected_recall = 1.0, test_distance = True):
5859
"""
5960
Perform a series of tests on a Flat index to test its conformance to expectations.
6061
Parameters:
@@ -67,6 +68,9 @@ def _do_test(self, flat, queries, groundtruth, expected_recall = 1.0):
6768
- Results of `search` are within acceptable margins of the groundtruth.
6869
- The number of threads can be changed with an observable side-effect.
6970
"""
71+
# Test get distance
72+
test_get_distance(flat, distance, data, test_distance)
73+
7074
# Data interface
7175
self.assertEqual(flat.size, test_number_of_vectors)
7276
self.assertEqual(flat.dimensions, test_dimensions)
@@ -117,7 +121,7 @@ def _do_test_from_file(self, distance: svs.DistanceType, queries, groundtruth):
117121
svs.VectorDataLoader(
118122
test_data_svs, svs.DataType.float32, dims = test_data_dims
119123
)
120-
);
124+
)
121125
for loader, recall in loaders:
122126
index = svs.Flat(
123127
loader,
@@ -126,7 +130,7 @@ def _do_test_from_file(self, distance: svs.DistanceType, queries, groundtruth):
126130
)
127131

128132
self.assertEqual(index.num_threads, num_threads)
129-
self._do_test(index, queries, groundtruth, expected_recall = recall[distance])
133+
self._do_test(index, queries, groundtruth, distance, expected_recall = recall[distance])
130134

131135
def test_from_file(self):
132136
"""
@@ -154,21 +158,22 @@ def test_from_array(self):
154158
# Test `float32`
155159
print("Flat, From Array, Float32")
156160
flat = svs.Flat(data_f32, svs.DistanceType.L2)
157-
self._do_test(flat, queries_f32, groundtruth)
161+
self._do_test(flat, queries_f32, groundtruth, svs.DistanceType.L2, data_f32)
158162

159163
# Test `float16`
160164
print("Flat, From Array, Float16")
161165
data_f16 = data_f32.astype('float16')
162166
queries_f16 = queries_f32.astype('float16')
163167
flat = svs.Flat(data_f16, svs.DistanceType.L2)
164-
self._do_test(flat, queries_f16, groundtruth)
168+
# Do not test get distance for fp16 data as py_contiguous_array_t does not support it
169+
self._do_test(flat, queries_f16, groundtruth, svs.DistanceType.L2, data_f16, test_distance = False)
165170

166171
# Test `int8`
167172
print("Flat, From Array, Int8")
168173
data_i8 = data_f32.astype('int8')
169174
queries_i8 = queries_f32.astype('int8')
170175
flat = svs.Flat(data_i8, svs.DistanceType.L2)
171-
self._do_test(flat, queries_i8, groundtruth)
176+
self._do_test(flat, queries_i8, groundtruth, svs.DistanceType.L2, data=data_i8)
172177

173178
# Test 'uint8'
174179
# The dataset is stored as values that can be encoded as `int8`.
@@ -178,4 +183,4 @@ def test_from_array(self):
178183
data_u8 = (data_f32 + 128).astype('uint8')
179184
queries_u8 = (queries_f32 + 128).astype('uint8')
180185
flat = svs.Flat(data_u8, svs.DistanceType.L2)
181-
self._do_test(flat, queries_u8, groundtruth)
186+
self._do_test(flat, queries_u8, groundtruth, svs.DistanceType.L2, data=data_u8)

bindings/python/tests/test_vamana.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@
3939
test_vamana_reference, \
4040
test_number_of_vectors, \
4141
test_dimensions, \
42-
get_test_set
42+
get_test_set, \
43+
test_get_distance
4344

4445
from .dataset import UncompressedMatcher
4546

@@ -305,6 +306,9 @@ def _test_build(
305306
vamana = svs.Vamana.build(params, loader, distance, num_threads = num_threads)
306307
print(f"Building: {vamana.experimental_backend_string}")
307308

309+
# Test get distance
310+
test_get_distance(vamana, distance)
311+
308312
groundtruth_map = self._groundtruth_map()
309313
# Load the queries and groundtruth
310314
queries = svs.read_vecs(test_queries)

include/svs/index/flat/flat.h

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,39 @@ data::GetDatumAccessor svs_invoke(svs::tag_t<accessor>, const Data& SVS_UNUSED(d
7070
return data::GetDatumAccessor{};
7171
}
7272

73+
/////
74+
///// Distance
75+
/////
76+
struct ComputeDistanceType {
77+
template <typename Data, typename Distance, typename Query>
78+
double operator()(
79+
const Data& data, const Distance& distance, size_t internal_id, const Query& query
80+
) const {
81+
return svs_invoke(*this, data, distance, internal_id, query);
82+
}
83+
};
84+
// CPO for distance computation
85+
inline constexpr ComputeDistanceType get_distance_ext{};
86+
template <typename Data, typename Distance, typename Query>
87+
double svs_invoke(
88+
svs::tag_t<get_distance_ext>,
89+
const Data& data,
90+
const Distance& distance,
91+
size_t internal_id,
92+
const Query& query
93+
) {
94+
// Get distance
95+
auto dist_f = extensions::distance(data, distance);
96+
svs::distance::maybe_fix_argument(dist_f, query);
97+
98+
// Get the vector from the index
99+
auto indexed_span = data.get_datum(internal_id);
100+
101+
// Compute the distance using the appropriate distance function
102+
auto dist = svs::distance::compute(dist_f, query, indexed_span);
103+
104+
return static_cast<double>(dist);
105+
}
73106
} // namespace extensions
74107

75108
// The flat index is "special" because we wish to enable the `FlatIndex` to either:
@@ -455,6 +488,30 @@ class FlatIndex {
455488
/// @brief Return the current thread pool handle.
456489
///
457490
threads::ThreadPoolHandle& get_threadpool_handle() { return threadpool_; }
491+
492+
///// Distance
493+
494+
/// @brief Compute the distance between an external vector and a vector in the index.
495+
template <typename Query> double get_distance(size_t id, const Query& query) const {
496+
// Check if id is valid
497+
if (id >= size()) {
498+
throw ANNEXCEPTION("ID {} is out of bounds for index of size {}!", id, size());
499+
}
500+
501+
// Verify dimensions match
502+
const size_t query_size = query.size();
503+
const size_t index_vector_size = dimensions();
504+
if (query_size != index_vector_size) {
505+
throw ANNEXCEPTION(
506+
"Incompatible dimensions. Query has {} while the index expects {}.",
507+
query_size,
508+
index_vector_size
509+
);
510+
}
511+
512+
// Call extension for distance computation
513+
return svs::index::flat::extensions::get_distance_ext(data_, distance_, id, query);
514+
}
458515
};
459516

460517
///

include/svs/index/vamana/dynamic_index.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "svs/index/flat/flat.h"
2424

2525
// svs
26+
#include "svs/concepts/distance.h"
2627
#include "svs/core/data.h"
2728
#include "svs/core/distance.h"
2829
#include "svs/core/graph.h"
@@ -1214,6 +1215,35 @@ class MutableVamanaIndex {
12141215
}
12151216
}
12161217
}
1218+
1219+
///// Distance
1220+
1221+
/// @brief Compute the distance between an external vector and a vector in the index.
1222+
template <typename ExternalId, typename Query>
1223+
double get_distance(const ExternalId& external_id, const Query& query) const {
1224+
// Check if the external ID exists
1225+
if (!has_id(external_id)) {
1226+
throw ANNEXCEPTION(
1227+
"ID {} is out of bounds for index of size {}!", external_id, size()
1228+
);
1229+
}
1230+
// Verify dimensions match
1231+
const size_t query_size = query.size();
1232+
const size_t index_vector_size = dimensions();
1233+
if (query_size != index_vector_size) {
1234+
throw ANNEXCEPTION(
1235+
"Incompatible dimensions. Query has {} while the index expects {}.",
1236+
query_size,
1237+
index_vector_size
1238+
);
1239+
}
1240+
1241+
// Translate external ID to internal ID
1242+
auto internal_id = translate_external_id(external_id);
1243+
1244+
// Call extension for distance computation
1245+
return extensions::get_distance_ext(data_, distance_, internal_id, query);
1246+
}
12171247
};
12181248

12191249
///// Deduction Guides.

include/svs/index/vamana/extensions.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -591,4 +591,37 @@ SVS_FORCE_INLINE data::GetDatumAccessor svs_invoke(
591591
return data::GetDatumAccessor();
592592
}
593593

594+
/////
595+
///// Distance
596+
/////
597+
struct ComputeDistanceType {
598+
template <typename Data, typename Distance, typename Query>
599+
double operator()(
600+
const Data& data, const Distance& distance, size_t internal_id, const Query& query
601+
) const {
602+
return svs_invoke(*this, data, distance, internal_id, query);
603+
}
604+
};
605+
// CPO for distance computation
606+
inline constexpr ComputeDistanceType get_distance_ext{};
607+
template <typename Data, typename Distance, typename Query>
608+
double svs_invoke(
609+
svs::tag_t<get_distance_ext>,
610+
const Data& data,
611+
const Distance& distance,
612+
size_t internal_id,
613+
const Query& query
614+
) {
615+
// Get distance
616+
auto dist_f = single_search_setup(data, distance);
617+
svs::distance::maybe_fix_argument(dist_f, query);
618+
619+
// Get the vector from the index
620+
auto indexed_span = data.get_datum(internal_id);
621+
622+
// Compute the distance using the appropriate distance function
623+
auto dist = svs::distance::compute(dist_f, query, indexed_span);
624+
return static_cast<double>(dist);
625+
}
626+
594627
} // namespace svs::index::vamana::extensions

include/svs/index/vamana/index.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -860,6 +860,29 @@ class VamanaIndex {
860860
template <typename F> void experimental_escape_hatch(F&& f) const {
861861
std::invoke(SVS_FWD(f), graph_, data_, distance_, lib::as_const_span(entry_point_));
862862
}
863+
864+
///// Distance
865+
866+
/// @brief Compute the distance between a vector in the index and a query vector
867+
template <typename Query> double get_distance(size_t id, const Query& query) const {
868+
// Check if id is valid
869+
if (id >= size()) {
870+
throw ANNEXCEPTION("ID {} is out of bounds for index of size {}!", id, size());
871+
}
872+
// Verify dimensions match
873+
const size_t query_size = query.size();
874+
const size_t index_vector_size = dimensions();
875+
if (query_size != index_vector_size) {
876+
throw ANNEXCEPTION(
877+
"Incompatible dimensions. Query has {} while the index expects {}.",
878+
query_size,
879+
index_vector_size
880+
);
881+
}
882+
883+
// Call extension for distance computation
884+
return extensions::get_distance_ext(data_, distance_, id, query);
885+
}
863886
};
864887

865888
// Shared documentation for assembly methods.

0 commit comments

Comments
 (0)