@@ -69,6 +69,8 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
69
69
std::mutex deleted_elements_lock; // lock for deleted_elements
70
70
std::unordered_set<tableint> deleted_elements; // contains internal ids of deleted elements
71
71
72
+ std::mutex repair_lock; // locks graph repair
73
+
72
74
73
75
HierarchicalNSW (SpaceInterface<dist_t > *s) {
74
76
}
@@ -190,9 +192,9 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
190
192
}
191
193
192
194
193
- int getRandomLevel (double reverse_size ) {
195
+ int getRandomLevel (double ml ) {
194
196
std::uniform_real_distribution<double > distribution (0.0 , 1.0 );
195
- double r = -log (distribution (level_generator_)) * reverse_size ;
197
+ double r = -log (distribution (level_generator_)) * ml ;
196
198
return (int ) r;
197
199
}
198
200
@@ -240,14 +242,8 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
240
242
241
243
std::unique_lock <std::mutex> lock (link_list_locks_[curNodeNum]);
242
244
243
- int *data; // = (int *)(linkList0_ + curNodeNum * size_links_per_element0_);
244
- if (layer == 0 ) {
245
- data = (int *)get_linklist0 (curNodeNum);
246
- } else {
247
- data = (int *)get_linklist (curNodeNum, layer);
248
- // data = (int *) (linkLists_[curNodeNum] + (layer - 1) * size_links_per_element_);
249
- }
250
- size_t size = getListCount ((linklistsizeint*)data);
245
+ linklistsizeint *data = get_linklist_at_level (curNodeNum, layer);
246
+ size_t size = getListCount (data);
251
247
tableint *datal = (tableint *) (data + 1 );
252
248
#ifdef USE_SSE
253
249
_mm_prefetch ((char *) (visited_array + *(data + 1 )), _MM_HINT_T0);
@@ -325,8 +321,8 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
325
321
candidate_set.pop ();
326
322
327
323
tableint current_node_id = current_node_pair.second ;
328
- int *data = ( int *) get_linklist0 (current_node_id);
329
- size_t size = getListCount ((linklistsizeint*) data);
324
+ linklistsizeint *data = get_linklist0 (current_node_id);
325
+ size_t size = getListCount (data);
330
326
// bool cur_node_deleted = isMarkedDeleted(current_node_id);
331
327
if (collect_metrics) {
332
328
metric_hops++;
@@ -471,11 +467,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
471
467
if (isUpdate) {
472
468
lock.lock ();
473
469
}
474
- linklistsizeint *ll_cur;
475
- if (level == 0 )
476
- ll_cur = get_linklist0 (cur_c);
477
- else
478
- ll_cur = get_linklist (cur_c, level);
470
+ linklistsizeint *ll_cur = get_linklist_at_level (cur_c, level);
479
471
480
472
if (*ll_cur && !isUpdate) {
481
473
throw std::runtime_error (" The newly inserted element should have blank link list" );
@@ -495,12 +487,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
495
487
for (size_t idx = 0 ; idx < selectedNeighbors.size (); idx++) {
496
488
std::unique_lock <std::mutex> lock (link_list_locks_[selectedNeighbors[idx]]);
497
489
498
- linklistsizeint *ll_other;
499
- if (level == 0 )
500
- ll_other = get_linklist0 (selectedNeighbors[idx]);
501
- else
502
- ll_other = get_linklist (selectedNeighbors[idx], level);
503
-
490
+ linklistsizeint *ll_other = get_linklist_at_level (selectedNeighbors[idx], level);
504
491
size_t sz_link_list_other = getListCount (ll_other);
505
492
506
493
if (sz_link_list_other > Mcurmax)
@@ -969,8 +956,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
969
956
970
957
{
971
958
std::unique_lock <std::mutex> lock (link_list_locks_[neigh]);
972
- linklistsizeint *ll_cur;
973
- ll_cur = get_linklist_at_level (neigh, layer);
959
+ linklistsizeint *ll_cur = get_linklist_at_level (neigh, layer);
974
960
size_t candSize = candidates.size ();
975
961
setListCount (ll_cur, candSize);
976
962
tableint *data = (tableint *) (ll_cur + 1 );
@@ -999,7 +985,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
999
985
bool changed = true ;
1000
986
while (changed) {
1001
987
changed = false ;
1002
- unsigned int *data;
988
+ linklistsizeint *data;
1003
989
std::unique_lock <std::mutex> lock (link_list_locks_[currObj]);
1004
990
data = get_linklist_at_level (currObj, level);
1005
991
int size = getListCount (data);
@@ -1057,7 +1043,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
1057
1043
1058
1044
std::vector<tableint> getConnectionsWithLock (tableint internalId, int level) {
1059
1045
std::unique_lock <std::mutex> lock (link_list_locks_[internalId]);
1060
- unsigned int *data = get_linklist_at_level (internalId, level);
1046
+ linklistsizeint *data = get_linklist_at_level (internalId, level);
1061
1047
int size = getListCount (data);
1062
1048
std::vector<tableint> result (size);
1063
1049
tableint *ll = (tableint *) (data + 1 );
@@ -1095,6 +1081,10 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
1095
1081
}
1096
1082
1097
1083
cur_c = cur_element_count;
1084
+ // use the element level as a flag to show that an element is not added yet
1085
+ // the element count is increased but no lock is aquired
1086
+ // so someone can start using the new element
1087
+ element_levels_[cur_c] = -1 ;
1098
1088
cur_element_count++;
1099
1089
label_lookup_[label] = cur_c;
1100
1090
}
@@ -1134,7 +1124,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
1134
1124
bool changed = true ;
1135
1125
while (changed) {
1136
1126
changed = false ;
1137
- unsigned int *data;
1127
+ linklistsizeint *data;
1138
1128
std::unique_lock <std::mutex> lock (link_list_locks_[currObj]);
1139
1129
data = get_linklist (currObj, level);
1140
1130
int size = getListCount (data);
@@ -1196,9 +1186,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
1196
1186
bool changed = true ;
1197
1187
while (changed) {
1198
1188
changed = false ;
1199
- unsigned int *data;
1200
-
1201
- data = (unsigned int *) get_linklist (currObj, level);
1189
+ linklistsizeint *data = get_linklist (currObj, level);
1202
1190
int size = getListCount (data);
1203
1191
metric_hops++;
1204
1192
metric_distance_computations+=size;
@@ -1271,5 +1259,110 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
1271
1259
}
1272
1260
std::cout << " integrity ok, checked " << connections_checked << " connections\n " ;
1273
1261
}
1262
+
1263
+
1264
+ void repair_zero_indegree () {
1265
+ // only one repair is allowed to be in progress at a time
1266
+ std::unique_lock <std::mutex> lock_repair (repair_lock);
1267
+
1268
+ int maxlevel_copy = maxlevel_;
1269
+ size_t element_count_copy = cur_element_count;
1270
+ std::vector<size_t > indegree (element_count_copy);
1271
+
1272
+ for (int level = maxlevel_copy; level >=0 ; level--) {
1273
+ std::fill (indegree.begin (), indegree.end (), 0 );
1274
+
1275
+ size_t m_max = level ? maxM_ : maxM0_;
1276
+ int num_elements = 0 ;
1277
+ // calculate in-degree
1278
+ for (tableint internal_id = 0 ; internal_id < element_count_copy; internal_id++) {
1279
+ // lock until addition is finished
1280
+ std::unique_lock <std::mutex> lock_el (link_list_locks_[internal_id]);
1281
+ // skip elements that are not in the current level
1282
+ // Note: if the element was not added to the graph before the lock
1283
+ // then element_level = -1 and we skip it as well
1284
+ int element_level = element_levels_[internal_id];
1285
+ if (element_level < level) {
1286
+ continue ;
1287
+ }
1288
+
1289
+ linklistsizeint *ll = get_linklist_at_level (internal_id, level);
1290
+ int size = getListCount (ll);
1291
+ tableint *datal = (tableint *) (ll + 1 );
1292
+ for (int i = 0 ; i < size; i++) {
1293
+ tableint nei_id = datal[i];
1294
+ // skip newly added elements
1295
+ if (nei_id >= element_count_copy) {
1296
+ continue ;
1297
+ }
1298
+ indegree[nei_id] += 1 ;
1299
+ }
1300
+ num_elements += 1 ;
1301
+ }
1302
+
1303
+ // skip levels with 1 element
1304
+ if (num_elements <= 1 ) {
1305
+ continue ;
1306
+ }
1307
+
1308
+ // fix elements with 0 in-degree
1309
+ for (tableint internal_id = 0 ; internal_id < element_count_copy; internal_id++) {
1310
+ int element_level = element_levels_[internal_id];
1311
+ if (element_level < level || indegree[internal_id] > 0 ) {
1312
+ continue ;
1313
+ }
1314
+
1315
+ char * data_point = getDataByInternalId (internal_id);
1316
+ tableint currObj = enterpoint_node_;
1317
+
1318
+ dist_t curdist = fstdistfunc_ (data_point, getDataByInternalId (currObj), dist_func_param_);
1319
+ for (int level_above = maxlevel_copy; level_above > level; level_above--) {
1320
+ bool changed = true ;
1321
+ while (changed) {
1322
+ changed = false ;
1323
+ linklistsizeint *data;
1324
+ std::unique_lock <std::mutex> lock (link_list_locks_[currObj]);
1325
+ data = get_linklist_at_level (currObj, level_above);
1326
+ int size = getListCount (data);
1327
+
1328
+ tableint *datal = (tableint *) (data + 1 );
1329
+ for (int i = 0 ; i < size; i++) {
1330
+ tableint cand = datal[i];
1331
+ dist_t d = fstdistfunc_ (data_point, getDataByInternalId (cand), dist_func_param_);
1332
+ if (d < curdist) {
1333
+ curdist = d;
1334
+ currObj = cand;
1335
+ changed = true ;
1336
+ }
1337
+ }
1338
+ }
1339
+ }
1340
+
1341
+ std::priority_queue<std::pair<dist_t , tableint>, std::vector<std::pair<dist_t , tableint>>, CompareByFirst> candidates = searchBaseLayer (
1342
+ currObj, data_point, level);
1343
+
1344
+ while (candidates.size () > 0 ) {
1345
+ tableint cand_id = candidates.top ().second ;
1346
+ // skip same element
1347
+ if (cand_id == internal_id) {
1348
+ candidates.pop ();
1349
+ continue ;
1350
+ }
1351
+
1352
+ // try to connect candidate to the element
1353
+ // add an edge if there is space
1354
+ std::unique_lock <std::mutex> lock (link_list_locks_[cand_id]);
1355
+ linklistsizeint *ll_cand = get_linklist_at_level (cand_id, level);
1356
+ tableint *data_cand = (tableint *) (ll_cand + 1 );
1357
+ size_t size = getListCount (ll_cand);
1358
+ if (size < m_max) {
1359
+ data_cand[size] = internal_id;
1360
+ setListCount (ll_cand, size + 1 );
1361
+ }
1362
+ candidates.pop ();
1363
+ }
1364
+ }
1365
+ }
1366
+ }
1274
1367
};
1275
1368
} // namespace hnswlib
0 commit comments