Skip to content

Commit 1415e05

Browse files
committed
Fix race condition in free-threaded Python (fixes issue #867)
This commit addresses an issue arising when multiple threads want to access the Python object associated with the same C++ instance, which does not exist yet and therefore must be created. @vfdev-5 reported that TSAN detects a race condition in code that uses this pattern, caused by concurrent unprotected reads/writes of internal ``nb_inst`` fields. There is also a larger problem: depending on how operations are sequenced, it is possible that two threads simultaneously create a Python wrapper, which violates the usual invariant that each (C++ instance pointer, type) pair maps to at most one Python object. This PR updates nanobind to preserve this invariant. When registering a newly created wrapper object in the internal data structures, nanobind checks if another equivalent wrapper has been created in the meantime. If so, we destroy the wrapper that the current thread just created and return the registered one. This requires some extra handling code, which, however, only runs with very low probability. This commit also adds a new ``registered`` bit flag to ``nb_inst``, which makes it possible to have ``nb_inst`` objects that aren't registered in the internal data structures. I am planning to use that feature to fix the (unrelated) issue #879.
1 parent 534fd8c commit 1415e05

File tree

3 files changed

+66
-16
lines changed

3 files changed

+66
-16
lines changed

src/nb_internals.h

+14-3
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,11 @@ struct nb_inst { // usually: 24 bytes
8787
/// Does this instance use intrusive reference counting?
8888
uint32_t intrusive : 1;
8989

90+
/// Is this instance registered in the nanobind instance map?
91+
uint32_t registered : 1;
92+
9093
// That's a lot of unused space. I wonder if there is a good use for it..
91-
uint32_t unused : 24;
94+
uint32_t unused : 23;
9295
};
9396

9497
static_assert(sizeof(nb_inst) == sizeof(PyObject) + sizeof(uint32_t) * 2);
@@ -489,8 +492,16 @@ template <typename T> struct scoped_pymalloc {
489492
#if defined(NB_FREE_THREADED)
490493
struct lock_shard {
491494
nb_shard &s;
492-
lock_shard(nb_shard &s) : s(s) { PyMutex_Lock(&s.mutex); }
493-
~lock_shard() { PyMutex_Unlock(&s.mutex); }
495+
bool active;
496+
lock_shard(nb_shard &s) : s(s), active(true) { PyMutex_Lock(&s.mutex); }
497+
~lock_shard() { unlock(); }
498+
499+
void unlock() {
500+
if (active) {
501+
PyMutex_Unlock(&s.mutex);
502+
active = false;
503+
}
504+
}
494505
};
495506
struct lock_internals {
496507
nb_internals *i;

src/nb_type.cpp

+48-9
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,10 @@ PyObject *inst_new_int(PyTypeObject *tp, PyObject * /* args */,
9898
self->cpp_delete = 0;
9999
self->clear_keep_alive = 0;
100100
self->intrusive = intrusive;
101+
self->registered = 1;
101102
self->unused = 0;
103+
104+
// Make the object compatible with nb_try_inc_ref (free-threaded builds only)
102105
nb_enable_try_inc_ref((PyObject *) self);
103106

104107
// Update hash table that maps from C++ to Python instance
@@ -111,7 +114,9 @@ PyObject *inst_new_int(PyTypeObject *tp, PyObject * /* args */,
111114
return (PyObject *) self;
112115
}
113116

114-
/// Allocate memory for a nb_type instance with external storage
117+
/// Allocate memory for a nb_type instance with external storage. In contrast to
118+
/// 'inst_new_int()', this does not yet register the instance in the internal
119+
/// data structures. The function 'inst_register()' must be used to do so.
115120
PyObject *inst_new_ext(PyTypeObject *tp, void *value) {
116121
bool gc = PyType_HasFeature(tp, Py_TPFLAGS_HAVE_GC);
117122

@@ -164,14 +169,27 @@ PyObject *inst_new_ext(PyTypeObject *tp, void *value) {
164169
self->cpp_delete = 0;
165170
self->clear_keep_alive = 0;
166171
self->intrusive = intrusive;
172+
173+
// We already set this flag to 1 here so that we don't have to change it again
174+
// afterwards. This requires that the call to 'inst_new_ext' is paired with
175+
// a call to 'inst_register()'
176+
177+
self->registered = 1;
167178
self->unused = 0;
179+
180+
// Make the object compatible with nb_try_inc_ref (free-threaded builds only)
168181
nb_enable_try_inc_ref((PyObject *) self);
169182

183+
return (PyObject *) self;
184+
}
185+
186+
/// Register the object constructed by 'inst_new_ext()' in the internal data structures
187+
static nb_inst *inst_register(nb_inst *inst, void *value) noexcept {
170188
nb_shard &shard = internals->shard(value);
171189
lock_shard guard(shard);
172190

173191
// Update hash table that maps from C++ to Python instance
174-
auto [it, success] = shard.inst_c2p.try_emplace(value, self);
192+
auto [it, success] = shard.inst_c2p.try_emplace(value, inst);
175193

176194
if (NB_UNLIKELY(!success)) {
177195
void *entry = it->second;
@@ -186,27 +204,45 @@ PyObject *inst_new_ext(PyTypeObject *tp, void *value) {
186204
entry = it.value() = nb_mark_seq(first);
187205
}
188206

207+
PyTypeObject *tp = Py_TYPE(inst);
189208
nb_inst_seq *seq = nb_get_seq(entry);
190209
while (true) {
191-
check((nb_inst *) seq->inst != self,
192-
"nanobind::detail::inst_new_ext(): duplicate instance!");
210+
nb_inst *inst_2 = (nb_inst *) seq->inst;
211+
PyTypeObject *tp_2 = Py_TYPE(inst_2);
212+
213+
// The following should never happen
214+
check(inst_2 != inst, "nanobind::detail::inst_new_ext(): duplicate instance!");
215+
216+
// In the case of concurrent execution, another thread might have created an
217+
// identical instance wrapper in the meantime. Let's return that one then.
218+
if (tp == tp_2 &&
219+
!(nb_type_data(tp_2)->flags & (uint32_t) type_flags::is_python_type) &&
220+
nb_try_inc_ref((PyObject *) inst_2)) {
221+
inst->destruct = inst->cpp_delete = inst->registered = false;
222+
guard.unlock();
223+
Py_DECREF(inst);
224+
return inst_2;
225+
}
226+
193227
if (!seq->next)
194228
break;
229+
195230
seq = seq->next;
196231
}
197232

198233
nb_inst_seq *next = (nb_inst_seq *) PyMem_Malloc(sizeof(nb_inst_seq));
199234
check(next,
200235
"nanobind::detail::inst_new_ext(): list element allocation failed!");
201236

202-
next->inst = (PyObject *) self;
237+
next->inst = (PyObject *) inst;
203238
next->next = nullptr;
204239
seq->next = next;
205240
}
206241

207-
return (PyObject *) self;
242+
return inst;
208243
}
209244

245+
210246
static void inst_dealloc(PyObject *self) {
211247
PyTypeObject *tp = Py_TYPE(self);
212248
const type_data *t = nb_type_data(tp);
@@ -253,7 +289,7 @@ static void inst_dealloc(PyObject *self) {
253289

254290
nb_weakref_seq *wr_seq = nullptr;
255291

256-
{
292+
if (inst->registered) {
257293
// Enter critical section of shard
258294
nb_shard &shard = internals->shard(p);
259295
lock_shard guard(shard);
@@ -1737,6 +1773,9 @@ static PyObject *nb_type_put_common(void *value, type_data *t, rv_policy rvp,
17371773
if (intrusive)
17381774
t->set_self_py(new_value, (PyObject *) inst);
17391775

1776+
if (!create_new)
1777+
inst = inst_register(inst, value);
1778+
17401779
return (PyObject *) inst;
17411780
}
17421781

@@ -2082,7 +2121,7 @@ PyObject *nb_inst_reference(PyTypeObject *t, void *ptr, PyObject *parent) {
20822121
nbi->state = nb_inst::state_ready;
20832122
if (parent)
20842123
keep_alive(result, parent);
2085-
return result;
2124+
return (PyObject *) inst_register(nbi, ptr);
20862125
}
20872126

20882127
PyObject *nb_inst_take_ownership(PyTypeObject *t, void *ptr) {
@@ -2092,7 +2131,7 @@ PyObject *nb_inst_take_ownership(PyTypeObject *t, void *ptr) {
20922131
nb_inst *nbi = (nb_inst *) result;
20932132
nbi->destruct = nbi->cpp_delete = true;
20942133
nbi->state = nb_inst::state_ready;
2095-
return result;
2134+
return (PyObject *) inst_register(nbi, ptr);
20962135
}
20972136

20982137
void *nb_inst_ptr(PyObject *o) noexcept {

tests/test_thread.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def test02_global_lock(n_threads=8):
2929
n = 100000
3030
c = Counter()
3131
def f():
32-
for i in range(n):
32+
for _ in range(n):
3333
t.inc_global(c)
3434

3535
parallelize(f, n_threads=n_threads)
@@ -53,7 +53,7 @@ def test04_locked_function(n_threads=8):
5353
n = 100000
5454
c = Counter()
5555
def f():
56-
for i in range(n):
56+
for _ in range(n):
5757
t.inc_safe(c)
5858

5959
parallelize(f, n_threads=n_threads)
@@ -77,11 +77,11 @@ def f():
7777
assert c.value == n * n_threads
7878

7979

80-
def test_06_global_wrapper(n_threads=8):
80+
def test06_global_wrapper(n_threads=8):
8181
# Check wrapper lookup racing with wrapper deallocation
8282
n = 10000
8383
def f():
84-
for i in range(n):
84+
for _ in range(n):
8585
GlobalData.get()
8686
GlobalData.get()
8787
GlobalData.get()

0 commit comments

Comments
 (0)