Skip to content

Commit c786d34

Browse files
XuehaiPanrwgkhenryiii
authored
fix: handle null py::handle and add tests for py::scoped_critical_section (#5706)
* chore: handle null for `py::scoped_critical_section` * test: add tests for `py::scoped_critical_section` * test: use assert instead of REQUIRE * feat: enable faulthandler for pytest * chore: use `__has_include(<barrier>)` * fix: fix segmentation fault in test * fix: test critical_section for no-gil only * test: run new tests only * test: ensure non-empty test selection * fix: fix test critical_section * fix: change Python 3.14.0b1/2 xfail tests to non-strict * test: trigger gc manually * test: mark xfail to `DynamicClass` * Use `namespace test_scoped_critical_section_ns` (standard approach to guard against name clashes). * Simplify changes in pybind11/critical_section.h and add test_nullptr_combinations() * test: disable Python devmode in pytest * test: add comprehensive comments for the tests * test: add a summary comment for tests * refactor: simpler impl Signed-off-by: Henry Schreiner <[email protected]> --------- Signed-off-by: Henry Schreiner <[email protected]> Co-authored-by: Ralf W. Grosse-Kunstleve <[email protected]> Co-authored-by: Henry Schreiner <[email protected]>
1 parent c7026d0 commit c786d34

File tree

4 files changed

+324
-12
lines changed

4 files changed

+324
-12
lines changed

include/pybind11/critical_section.h

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,24 +13,30 @@ PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
1313
class scoped_critical_section {
1414
public:
1515
#ifdef Py_GIL_DISABLED
16-
explicit scoped_critical_section(handle obj) : has2(false) {
17-
PyCriticalSection_Begin(&section, obj.ptr());
18-
}
19-
20-
scoped_critical_section(handle obj1, handle obj2) : has2(true) {
21-
PyCriticalSection2_Begin(&section2, obj1.ptr(), obj2.ptr());
16+
explicit scoped_critical_section(handle obj1, handle obj2 = handle{}) {
17+
if (obj1) {
18+
if (obj2) {
19+
PyCriticalSection2_Begin(&section2, obj1.ptr(), obj2.ptr());
20+
rank = 2;
21+
} else {
22+
PyCriticalSection_Begin(&section, obj1.ptr());
23+
rank = 1;
24+
}
25+
} else if (obj2) {
26+
PyCriticalSection_Begin(&section, obj2.ptr());
27+
rank = 1;
28+
}
2229
}
2330

2431
~scoped_critical_section() {
25-
if (has2) {
26-
PyCriticalSection2_End(&section2);
27-
} else {
32+
if (rank == 1) {
2833
PyCriticalSection_End(&section);
34+
} else if (rank == 2) {
35+
PyCriticalSection2_End(&section2);
2936
}
3037
}
3138
#else
32-
explicit scoped_critical_section(handle) {};
33-
scoped_critical_section(handle, handle) {};
39+
explicit scoped_critical_section(handle, handle = handle{}) {};
3440
~scoped_critical_section() = default;
3541
#endif
3642

@@ -39,7 +45,7 @@ class scoped_critical_section {
3945

4046
private:
4147
#ifdef Py_GIL_DISABLED
42-
bool has2;
48+
int rank{0};
4349
union {
4450
PyCriticalSection section;
4551
PyCriticalSection2 section2;

tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ set(PYBIND11_TEST_FILES
166166
test_potentially_slicing_weak_ptr
167167
test_python_multiple_inheritance
168168
test_pytypes
169+
test_scoped_critical_section
169170
test_sequences_and_iterators
170171
test_smart_ptr
171172
test_stl
Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,275 @@
1+
#include <pybind11/critical_section.h>
2+
3+
#include "pybind11_tests.h"
4+
5+
#include <atomic>
6+
#include <chrono>
7+
#include <thread>
8+
#include <utility>
9+
10+
#if defined(PYBIND11_CPP20) && defined(__has_include) && __has_include(<barrier>)
11+
# define PYBIND11_HAS_BARRIER 1
12+
# include <barrier>
13+
#endif
14+
15+
namespace test_scoped_critical_section_ns {
16+
17+
void test_one_nullptr() { py::scoped_critical_section lock{py::handle{}}; }
18+
19+
void test_two_nullptrs() { py::scoped_critical_section lock{py::handle{}, py::handle{}}; }
20+
21+
void test_first_nullptr() {
22+
py::dict d;
23+
py::scoped_critical_section lock{py::handle{}, d};
24+
}
25+
26+
void test_second_nullptr() {
27+
py::dict d;
28+
py::scoped_critical_section lock{d, py::handle{}};
29+
}
30+
31+
// Referenced test implementation: https://github.com/PyO3/pyo3/blob/v0.25.0/src/sync.rs#L874
32+
class BoolWrapper {
33+
public:
34+
explicit BoolWrapper(bool value) : value_{value} {}
35+
bool get() const { return value_.load(std::memory_order_acquire); }
36+
void set(bool value) { value_.store(value, std::memory_order_release); }
37+
38+
private:
39+
std::atomic_bool value_{false};
40+
};
41+
42+
#if defined(PYBIND11_HAS_BARRIER)
43+
44+
// Modifying the C/C++ members of a Python object from multiple threads requires a critical section
45+
// to ensure thread safety and data integrity.
46+
// These tests use a scoped critical section to ensure that the Python object is accessed in a
47+
// thread-safe manner.
48+
49+
void test_scoped_critical_section(const py::handle &cls) {
50+
auto barrier = std::barrier(2);
51+
auto bool_wrapper = cls(false);
52+
bool output = false;
53+
54+
{
55+
// Release the GIL to allow run threads in parallel.
56+
py::gil_scoped_release gil_release{};
57+
58+
std::thread t1([&]() {
59+
// Use gil_scoped_acquire to ensure we have a valid Python thread state
60+
// before entering the critical section. Otherwise, the critical section
61+
// will cause a segmentation fault.
62+
py::gil_scoped_acquire ensure_tstate{};
63+
// Enter the critical section with the same object as the second thread.
64+
py::scoped_critical_section lock{bool_wrapper};
65+
// At this point, the object is locked by this thread via the scoped_critical_section.
66+
// This barrier will ensure that the second thread waits until this thread has released
67+
// the critical section before proceeding.
68+
barrier.arrive_and_wait();
69+
// Sleep for a short time to simulate some work in the critical section.
70+
// This sleep is necessary to test the locking mechanism properly.
71+
std::this_thread::sleep_for(std::chrono::milliseconds(10));
72+
auto *bw = bool_wrapper.cast<BoolWrapper *>();
73+
bw->set(true);
74+
});
75+
76+
std::thread t2([&]() {
77+
// This thread will wait until the first thread has entered the critical section due to
78+
// the barrier.
79+
barrier.arrive_and_wait();
80+
{
81+
// Use gil_scoped_acquire to ensure we have a valid Python thread state
82+
// before entering the critical section. Otherwise, the critical section
83+
// will cause a segmentation fault.
84+
py::gil_scoped_acquire ensure_tstate{};
85+
// Enter the critical section with the same object as the first thread.
86+
py::scoped_critical_section lock{bool_wrapper};
87+
// At this point, the critical section is released by the first thread, the value
88+
// is set to true.
89+
auto *bw = bool_wrapper.cast<BoolWrapper *>();
90+
output = bw->get();
91+
}
92+
});
93+
94+
t1.join();
95+
t2.join();
96+
}
97+
98+
if (!output) {
99+
throw std::runtime_error("Scoped critical section test failed: output is false");
100+
}
101+
}
102+
103+
void test_scoped_critical_section2(const py::handle &cls) {
104+
auto barrier = std::barrier(3);
105+
auto bool_wrapper1 = cls(false);
106+
auto bool_wrapper2 = cls(false);
107+
std::pair<bool, bool> output{false, false};
108+
109+
{
110+
// Release the GIL to allow run threads in parallel.
111+
py::gil_scoped_release gil_release{};
112+
113+
std::thread t1([&]() {
114+
// Use gil_scoped_acquire to ensure we have a valid Python thread state
115+
// before entering the critical section. Otherwise, the critical section
116+
// will cause a segmentation fault.
117+
py::gil_scoped_acquire ensure_tstate{};
118+
// Enter the critical section with two different objects.
119+
// This will ensure that the critical section is locked for both objects.
120+
py::scoped_critical_section lock{bool_wrapper1, bool_wrapper2};
121+
// At this point, objects are locked by this thread via the scoped_critical_section.
122+
// This barrier will ensure that other threads wait until this thread has released
123+
// the critical section before proceeding.
124+
barrier.arrive_and_wait();
125+
// Sleep for a short time to simulate some work in the critical section.
126+
// This sleep is necessary to test the locking mechanism properly.
127+
std::this_thread::sleep_for(std::chrono::milliseconds(10));
128+
auto *bw1 = bool_wrapper1.cast<BoolWrapper *>();
129+
auto *bw2 = bool_wrapper2.cast<BoolWrapper *>();
130+
bw1->set(true);
131+
bw2->set(true);
132+
});
133+
134+
std::thread t2([&]() {
135+
// This thread will wait until the first thread has entered the critical section due to
136+
// the barrier.
137+
barrier.arrive_and_wait();
138+
{
139+
// Use gil_scoped_acquire to ensure we have a valid Python thread state
140+
// before entering the critical section. Otherwise, the critical section
141+
// will cause a segmentation fault.
142+
py::gil_scoped_acquire ensure_tstate{};
143+
// Enter the critical section with the same object as the first thread.
144+
py::scoped_critical_section lock{bool_wrapper1};
145+
// At this point, the critical section is released by the first thread, the value
146+
// is set to true.
147+
auto *bw1 = bool_wrapper1.cast<BoolWrapper *>();
148+
output.first = bw1->get();
149+
}
150+
});
151+
152+
std::thread t3([&]() {
153+
// This thread will wait until the first thread has entered the critical section due to
154+
// the barrier.
155+
barrier.arrive_and_wait();
156+
{
157+
// Use gil_scoped_acquire to ensure we have a valid Python thread state
158+
// before entering the critical section. Otherwise, the critical section
159+
// will cause a segmentation fault.
160+
py::gil_scoped_acquire ensure_tstate{};
161+
// Enter the critical section with the same object as the first thread.
162+
py::scoped_critical_section lock{bool_wrapper2};
163+
// At this point, the critical section is released by the first thread, the value
164+
// is set to true.
165+
auto *bw2 = bool_wrapper2.cast<BoolWrapper *>();
166+
output.second = bw2->get();
167+
}
168+
});
169+
170+
t1.join();
171+
t2.join();
172+
t3.join();
173+
}
174+
175+
if (!output.first || !output.second) {
176+
throw std::runtime_error(
177+
"Scoped critical section test with two objects failed: output is false");
178+
}
179+
}
180+
181+
void test_scoped_critical_section2_same_object_no_deadlock(const py::handle &cls) {
182+
auto barrier = std::barrier(2);
183+
auto bool_wrapper = cls(false);
184+
bool output = false;
185+
186+
{
187+
// Release the GIL to allow run threads in parallel.
188+
py::gil_scoped_release gil_release{};
189+
190+
std::thread t1([&]() {
191+
// Use gil_scoped_acquire to ensure we have a valid Python thread state
192+
// before entering the critical section. Otherwise, the critical section
193+
// will cause a segmentation fault.
194+
py::gil_scoped_acquire ensure_tstate{};
195+
// Enter the critical section with the same object as the second thread.
196+
py::scoped_critical_section lock{bool_wrapper, bool_wrapper}; // same object used here
197+
// At this point, the object is locked by this thread via the scoped_critical_section.
198+
// This barrier will ensure that the second thread waits until this thread has released
199+
// the critical section before proceeding.
200+
barrier.arrive_and_wait();
201+
// Sleep for a short time to simulate some work in the critical section.
202+
// This sleep is necessary to test the locking mechanism properly.
203+
std::this_thread::sleep_for(std::chrono::milliseconds(10));
204+
auto *bw = bool_wrapper.cast<BoolWrapper *>();
205+
bw->set(true);
206+
});
207+
208+
std::thread t2([&]() {
209+
// This thread will wait until the first thread has entered the critical section due to
210+
// the barrier.
211+
barrier.arrive_and_wait();
212+
{
213+
// Use gil_scoped_acquire to ensure we have a valid Python thread state
214+
// before entering the critical section. Otherwise, the critical section
215+
// will cause a segmentation fault.
216+
py::gil_scoped_acquire ensure_tstate{};
217+
// Enter the critical section with the same object as the first thread.
218+
py::scoped_critical_section lock{bool_wrapper};
219+
// At this point, the critical section is released by the first thread, the value
220+
// is set to true.
221+
auto *bw = bool_wrapper.cast<BoolWrapper *>();
222+
output = bw->get();
223+
}
224+
});
225+
226+
t1.join();
227+
t2.join();
228+
}
229+
230+
if (!output) {
231+
throw std::runtime_error(
232+
"Scoped critical section test with same object failed: output is false");
233+
}
234+
}
235+
236+
#else
237+
238+
void test_scoped_critical_section(const py::handle &) {}
239+
void test_scoped_critical_section2(const py::handle &) {}
240+
void test_scoped_critical_section2_same_object_no_deadlock(const py::handle &) {}
241+
242+
#endif
243+
244+
} // namespace test_scoped_critical_section_ns
245+
246+
TEST_SUBMODULE(scoped_critical_section, m) {
247+
using namespace test_scoped_critical_section_ns;
248+
249+
m.def("test_one_nullptr", test_one_nullptr);
250+
m.def("test_two_nullptrs", test_two_nullptrs);
251+
m.def("test_first_nullptr", test_first_nullptr);
252+
m.def("test_second_nullptr", test_second_nullptr);
253+
254+
auto BoolWrapperClass = py::class_<BoolWrapper>(m, "BoolWrapper")
255+
.def(py::init<bool>())
256+
.def("get", &BoolWrapper::get)
257+
.def("set", &BoolWrapper::set);
258+
auto BoolWrapperHandle = py::handle(BoolWrapperClass);
259+
(void) BoolWrapperHandle.ptr(); // suppress unused variable warning
260+
261+
m.attr("has_barrier") =
262+
#ifdef PYBIND11_HAS_BARRIER
263+
true;
264+
#else
265+
false;
266+
#endif
267+
268+
m.def("test_scoped_critical_section",
269+
[BoolWrapperHandle]() -> void { test_scoped_critical_section(BoolWrapperHandle); });
270+
m.def("test_scoped_critical_section2",
271+
[BoolWrapperHandle]() -> void { test_scoped_critical_section2(BoolWrapperHandle); });
272+
m.def("test_scoped_critical_section2_same_object_no_deadlock", [BoolWrapperHandle]() -> void {
273+
test_scoped_critical_section2_same_object_no_deadlock(BoolWrapperHandle);
274+
});
275+
}

tests/test_scoped_critical_section.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from __future__ import annotations
2+
3+
import pytest
4+
5+
from pybind11_tests import scoped_critical_section as m
6+
7+
8+
def test_nullptr_combinations():
9+
m.test_one_nullptr()
10+
m.test_two_nullptrs()
11+
m.test_first_nullptr()
12+
m.test_second_nullptr()
13+
14+
15+
@pytest.mark.skipif(not m.has_barrier, reason="no <barrier>")
16+
def test_scoped_critical_section() -> None:
17+
for _ in range(64):
18+
m.test_scoped_critical_section()
19+
20+
21+
@pytest.mark.skipif(not m.has_barrier, reason="no <barrier>")
22+
def test_scoped_critical_section2() -> None:
23+
for _ in range(64):
24+
m.test_scoped_critical_section2()
25+
26+
27+
@pytest.mark.skipif(not m.has_barrier, reason="no <barrier>")
28+
def test_scoped_critical_section2_same_object_no_deadlock() -> None:
29+
for _ in range(64):
30+
m.test_scoped_critical_section2_same_object_no_deadlock()

0 commit comments

Comments
 (0)