Skip to content

Commit 5100d3f

Browse files
committed
Repair graph method nmslib#515
1 parent 5aba40d commit 5100d3f

File tree

4 files changed

+253
-31
lines changed

4 files changed

+253
-31
lines changed

.github/workflows/build.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,4 +73,5 @@ jobs:
7373
./multiThread_replace_test
7474
./test_updates
7575
./test_updates update
76+
./repair_test
7677
shell: bash

CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,4 +53,7 @@ if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME)
5353

5454
add_executable(main tests/cpp/main.cpp tests/cpp/sift_1b.cpp)
5555
target_link_libraries(main hnswlib)
56+
57+
add_executable(repair_test tests/cpp/repair_test.cpp)
58+
target_link_libraries(repair_test hnswlib)
5659
endif()

hnswlib/hnswalg.h

Lines changed: 124 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
6969
std::mutex deleted_elements_lock; // lock for deleted_elements
7070
std::unordered_set<tableint> deleted_elements; // contains internal ids of deleted elements
7171

72+
std::mutex repair_lock; // locks graph repair
73+
7274

7375
HierarchicalNSW(SpaceInterface<dist_t> *s) {
7476
}
@@ -190,9 +192,9 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
190192
}
191193

192194

193-
int getRandomLevel(double reverse_size) {
195+
int getRandomLevel(double ml) {
194196
std::uniform_real_distribution<double> distribution(0.0, 1.0);
195-
double r = -log(distribution(level_generator_)) * reverse_size;
197+
double r = -log(distribution(level_generator_)) * ml;
196198
return (int) r;
197199
}
198200

@@ -240,14 +242,8 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
240242

241243
std::unique_lock <std::mutex> lock(link_list_locks_[curNodeNum]);
242244

243-
int *data; // = (int *)(linkList0_ + curNodeNum * size_links_per_element0_);
244-
if (layer == 0) {
245-
data = (int*)get_linklist0(curNodeNum);
246-
} else {
247-
data = (int*)get_linklist(curNodeNum, layer);
248-
// data = (int *) (linkLists_[curNodeNum] + (layer - 1) * size_links_per_element_);
249-
}
250-
size_t size = getListCount((linklistsizeint*)data);
245+
linklistsizeint *data = get_linklist_at_level(curNodeNum, layer);
246+
size_t size = getListCount(data);
251247
tableint *datal = (tableint *) (data + 1);
252248
#ifdef USE_SSE
253249
_mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0);
@@ -325,8 +321,8 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
325321
candidate_set.pop();
326322

327323
tableint current_node_id = current_node_pair.second;
328-
int *data = (int *) get_linklist0(current_node_id);
329-
size_t size = getListCount((linklistsizeint*)data);
324+
linklistsizeint *data = get_linklist0(current_node_id);
325+
size_t size = getListCount(data);
330326
// bool cur_node_deleted = isMarkedDeleted(current_node_id);
331327
if (collect_metrics) {
332328
metric_hops++;
@@ -471,11 +467,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
471467
if (isUpdate) {
472468
lock.lock();
473469
}
474-
linklistsizeint *ll_cur;
475-
if (level == 0)
476-
ll_cur = get_linklist0(cur_c);
477-
else
478-
ll_cur = get_linklist(cur_c, level);
470+
linklistsizeint *ll_cur = get_linklist_at_level(cur_c, level);
479471

480472
if (*ll_cur && !isUpdate) {
481473
throw std::runtime_error("The newly inserted element should have blank link list");
@@ -495,12 +487,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
495487
for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) {
496488
std::unique_lock <std::mutex> lock(link_list_locks_[selectedNeighbors[idx]]);
497489

498-
linklistsizeint *ll_other;
499-
if (level == 0)
500-
ll_other = get_linklist0(selectedNeighbors[idx]);
501-
else
502-
ll_other = get_linklist(selectedNeighbors[idx], level);
503-
490+
linklistsizeint *ll_other = get_linklist_at_level(selectedNeighbors[idx], level);
504491
size_t sz_link_list_other = getListCount(ll_other);
505492

506493
if (sz_link_list_other > Mcurmax)
@@ -969,8 +956,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
969956

970957
{
971958
std::unique_lock <std::mutex> lock(link_list_locks_[neigh]);
972-
linklistsizeint *ll_cur;
973-
ll_cur = get_linklist_at_level(neigh, layer);
959+
linklistsizeint *ll_cur = get_linklist_at_level(neigh, layer);
974960
size_t candSize = candidates.size();
975961
setListCount(ll_cur, candSize);
976962
tableint *data = (tableint *) (ll_cur + 1);
@@ -999,7 +985,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
999985
bool changed = true;
1000986
while (changed) {
1001987
changed = false;
1002-
unsigned int *data;
988+
linklistsizeint *data;
1003989
std::unique_lock <std::mutex> lock(link_list_locks_[currObj]);
1004990
data = get_linklist_at_level(currObj, level);
1005991
int size = getListCount(data);
@@ -1057,7 +1043,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
10571043

10581044
std::vector<tableint> getConnectionsWithLock(tableint internalId, int level) {
10591045
std::unique_lock <std::mutex> lock(link_list_locks_[internalId]);
1060-
unsigned int *data = get_linklist_at_level(internalId, level);
1046+
linklistsizeint *data = get_linklist_at_level(internalId, level);
10611047
int size = getListCount(data);
10621048
std::vector<tableint> result(size);
10631049
tableint *ll = (tableint *) (data + 1);
@@ -1095,6 +1081,10 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
10951081
}
10961082

10971083
cur_c = cur_element_count;
1084+
// use the element level as a flag to show that an element is not added yet
1085+
// the element count is increased but no lock is acquired
1086+
// so someone can start using the new element
1087+
element_levels_[cur_c] = -1;
10981088
cur_element_count++;
10991089
label_lookup_[label] = cur_c;
11001090
}
@@ -1134,7 +1124,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
11341124
bool changed = true;
11351125
while (changed) {
11361126
changed = false;
1137-
unsigned int *data;
1127+
linklistsizeint *data;
11381128
std::unique_lock <std::mutex> lock(link_list_locks_[currObj]);
11391129
data = get_linklist(currObj, level);
11401130
int size = getListCount(data);
@@ -1196,9 +1186,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
11961186
bool changed = true;
11971187
while (changed) {
11981188
changed = false;
1199-
unsigned int *data;
1200-
1201-
data = (unsigned int *) get_linklist(currObj, level);
1189+
linklistsizeint *data = get_linklist(currObj, level);
12021190
int size = getListCount(data);
12031191
metric_hops++;
12041192
metric_distance_computations+=size;
@@ -1271,5 +1259,110 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
12711259
}
12721260
std::cout << "integrity ok, checked " << connections_checked << " connections\n";
12731261
}
1262+
1263+
1264+
void repair_zero_indegree() {
1265+
// only one repair is allowed to be in progress at a time
1266+
std::unique_lock <std::mutex> lock_repair(repair_lock);
1267+
1268+
int maxlevel_copy = maxlevel_;
1269+
size_t element_count_copy = cur_element_count;
1270+
std::vector<size_t> indegree(element_count_copy);
1271+
1272+
for (int level = maxlevel_copy; level >=0 ; level--) {
1273+
std::fill(indegree.begin(), indegree.end(), 0);
1274+
1275+
size_t m_max = level ? maxM_ : maxM0_;
1276+
int num_elements = 0;
1277+
// calculate in-degree
1278+
for (tableint internal_id = 0; internal_id < element_count_copy; internal_id++) {
1279+
// lock until addition is finished
1280+
std::unique_lock <std::mutex> lock_el(link_list_locks_[internal_id]);
1281+
// skip elements that are not in the current level
1282+
// Note: if the element was not added to the graph before the lock
1283+
// then element_level = -1 and we skip it as well
1284+
int element_level = element_levels_[internal_id];
1285+
if (element_level < level) {
1286+
continue;
1287+
}
1288+
1289+
linklistsizeint *ll = get_linklist_at_level(internal_id, level);
1290+
int size = getListCount(ll);
1291+
tableint *datal = (tableint *) (ll + 1);
1292+
for (int i = 0; i < size; i++) {
1293+
tableint nei_id = datal[i];
1294+
// skip newly added elements
1295+
if (nei_id >= element_count_copy) {
1296+
continue;
1297+
}
1298+
indegree[nei_id] += 1;
1299+
}
1300+
num_elements += 1;
1301+
}
1302+
1303+
// skip levels with 1 element
1304+
if (num_elements <= 1) {
1305+
continue;
1306+
}
1307+
1308+
// fix elements with 0 in-degree
1309+
for (tableint internal_id = 0; internal_id < element_count_copy; internal_id++) {
1310+
int element_level = element_levels_[internal_id];
1311+
if (element_level < level || indegree[internal_id] > 0) {
1312+
continue;
1313+
}
1314+
1315+
char* data_point = getDataByInternalId(internal_id);
1316+
tableint currObj = enterpoint_node_;
1317+
1318+
dist_t curdist = fstdistfunc_(data_point, getDataByInternalId(currObj), dist_func_param_);
1319+
for (int level_above = maxlevel_copy; level_above > level; level_above--) {
1320+
bool changed = true;
1321+
while (changed) {
1322+
changed = false;
1323+
linklistsizeint *data;
1324+
std::unique_lock <std::mutex> lock(link_list_locks_[currObj]);
1325+
data = get_linklist_at_level(currObj, level_above);
1326+
int size = getListCount(data);
1327+
1328+
tableint *datal = (tableint *) (data + 1);
1329+
for (int i = 0; i < size; i++) {
1330+
tableint cand = datal[i];
1331+
dist_t d = fstdistfunc_(data_point, getDataByInternalId(cand), dist_func_param_);
1332+
if (d < curdist) {
1333+
curdist = d;
1334+
currObj = cand;
1335+
changed = true;
1336+
}
1337+
}
1338+
}
1339+
}
1340+
1341+
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates = searchBaseLayer(
1342+
currObj, data_point, level);
1343+
1344+
while (candidates.size() > 0) {
1345+
tableint cand_id = candidates.top().second;
1346+
// skip same element
1347+
if (cand_id == internal_id) {
1348+
candidates.pop();
1349+
continue;
1350+
}
1351+
1352+
// try to connect candidate to the element
1353+
// add an edge if there is space
1354+
std::unique_lock <std::mutex> lock(link_list_locks_[cand_id]);
1355+
linklistsizeint *ll_cand = get_linklist_at_level(cand_id, level);
1356+
tableint *data_cand = (tableint *) (ll_cand + 1);
1357+
size_t size = getListCount(ll_cand);
1358+
if (size < m_max) {
1359+
data_cand[size] = internal_id;
1360+
setListCount(ll_cand, size + 1);
1361+
}
1362+
candidates.pop();
1363+
}
1364+
}
1365+
}
1366+
}
12741367
};
12751368
} // namespace hnswlib

tests/cpp/repair_test.cpp

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
#include "../../hnswlib/hnswlib.h"
2+
#include <thread>
3+
4+
5+
// Checks, level by level, that every element stored on a level with more than
// one element is referenced by at least one link on that level.
// Prints a line for each offending element and returns false if any was found.
bool is_indegree_ok(hnswlib::HierarchicalNSW<float>* alg_hnsw) {
    bool all_linked = true;
    std::vector<int> incoming(alg_hnsw->cur_element_count);

    for (int level = alg_hnsw->maxlevel_; level >= 0; --level) {
        std::fill(incoming.begin(), incoming.end(), 0);

        // first pass: count incoming links of the elements present on this level
        int members = 0;
        for (int id = 0; id < alg_hnsw->cur_element_count; ++id) {
            if (alg_hnsw->element_levels_[id] >= level) {
                for (hnswlib::tableint target : alg_hnsw->getConnectionsWithLock(id, level)) {
                    ++incoming[target];
                }
                ++members;
            }
        }

        // a level holding a single element cannot have incoming links, nothing to check
        if (members > 1) {
            // second pass: report members of this level nobody links to
            for (int id = 0; id < alg_hnsw->cur_element_count; ++id) {
                const bool present = alg_hnsw->element_levels_[id] >= level;
                if (present && incoming[id] == 0) {
                    std::cout << "zero in-degree node found, level=" << level << " id=" << id << "\n" << std::flush;
                    all_linked = false;
                }
            }
        }
    }

    return all_linked;
}
44+
45+
46+
int main() {
47+
int dim = 4; // Dimension of the elements
48+
int n = 1000; // Maximum number of elements, should be known beforehand
49+
int M = 8; // Tightly connected with internal dimensionality of the data
50+
// strongly affects the memory consumption
51+
int ef_construction = 200; // Controls index search speed/build speed tradeoff
52+
int num_test_iter = 5;
53+
54+
int test_id = 0;
55+
56+
std::mt19937 rng;
57+
rng.seed(47);
58+
std::uniform_real_distribution<> distrib_real;
59+
while (test_id < num_test_iter) {
60+
// Initing index
61+
std::cout << "Initing index" << std::endl;
62+
hnswlib::L2Space space(dim);
63+
hnswlib::HierarchicalNSW<float>* alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, n * 3, M, ef_construction, 100, true);
64+
65+
// Generate random data
66+
float* data = new float[dim * n];
67+
for (int i = 0; i < dim * n; i++) {
68+
data[i] = distrib_real(rng);
69+
}
70+
71+
std::cout << "Add data to index" << std::endl;
72+
73+
// Add data to index
74+
for (int i = 0; i < n; i++) {
75+
//std::cout << "insert " << i << std::endl;
76+
alg_hnsw->addPoint(data + i * dim, i, true);
77+
}
78+
std::cout << test_id << " Index is ready\n";
79+
80+
std::vector<std::thread> threads;
81+
82+
// mix new inserts with modifications (50% of operations are new)
83+
for(int i = 0; i < n; i += 100) {
84+
//std::cout << "mixed insert " << i << std::endl;
85+
86+
threads.emplace_back([alg_hnsw, data, i, dim]() {
87+
std::uniform_real_distribution<> distrib_real;
88+
std::mt19937 rng;
89+
rng.seed(49);
90+
91+
for(auto j = 0; j < 10; j++) {
92+
auto actual_index = i + j;
93+
auto id = ( actual_index % 2 != 0) ? actual_index + 10000 : actual_index;
94+
std::vector<float> values;
95+
for (size_t j = 0; j < dim; j++) {
96+
values.push_back(distrib_real(rng) + 0.01);
97+
}
98+
alg_hnsw->addPoint(values.data(), id, true);
99+
}
100+
});
101+
}
102+
103+
// add repair method to check concurrency
104+
threads.emplace_back([alg_hnsw] {
105+
alg_hnsw->repair_zero_indegree();
106+
});
107+
108+
for(auto& t: threads) {
109+
t.join();
110+
}
111+
112+
bool is_ok_before_flag = is_indegree_ok(alg_hnsw);
113+
// fix in-degree if it is broken
114+
if (!is_ok_before_flag) {
115+
alg_hnsw->repair_zero_indegree();
116+
}
117+
bool is_ok_after_flag = is_indegree_ok(alg_hnsw);
118+
assert(is_ok_after_flag);
119+
test_id += 1;
120+
121+
delete[] data;
122+
delete alg_hnsw;
123+
}
124+
return 0;
125+
}

0 commit comments

Comments
 (0)