|
| 1 | +#include <pybind11/critical_section.h> |
| 2 | + |
| 3 | +#include "pybind11_tests.h" |
| 4 | + |
| 5 | +#include <atomic> |
| 6 | +#include <chrono> |
| 7 | +#include <thread> |
| 8 | +#include <utility> |
| 9 | + |
| 10 | +#if defined(PYBIND11_CPP20) && defined(__has_include) && __has_include(<barrier>) |
| 11 | +# define PYBIND11_HAS_BARRIER 1 |
| 12 | +# include <barrier> |
| 13 | +#endif |
| 14 | + |
| 15 | +namespace test_scoped_critical_section_ns { |
| 16 | + |
| 17 | +void test_one_nullptr() { py::scoped_critical_section lock{py::handle{}}; } |
| 18 | + |
| 19 | +void test_two_nullptrs() { py::scoped_critical_section lock{py::handle{}, py::handle{}}; } |
| 20 | + |
| 21 | +void test_first_nullptr() { |
| 22 | + py::dict d; |
| 23 | + py::scoped_critical_section lock{py::handle{}, d}; |
| 24 | +} |
| 25 | + |
| 26 | +void test_second_nullptr() { |
| 27 | + py::dict d; |
| 28 | + py::scoped_critical_section lock{d, py::handle{}}; |
| 29 | +} |
| 30 | + |
| 31 | +// Referenced test implementation: https://github.com/PyO3/pyo3/blob/v0.25.0/src/sync.rs#L874 |
| 32 | +class BoolWrapper { |
| 33 | +public: |
| 34 | + explicit BoolWrapper(bool value) : value_{value} {} |
| 35 | + bool get() const { return value_.load(std::memory_order_acquire); } |
| 36 | + void set(bool value) { value_.store(value, std::memory_order_release); } |
| 37 | + |
| 38 | +private: |
| 39 | + std::atomic_bool value_{false}; |
| 40 | +}; |
| 41 | + |
| 42 | +#if defined(PYBIND11_HAS_BARRIER) |
| 43 | + |
| 44 | +// Modifying the C/C++ members of a Python object from multiple threads requires a critical section |
| 45 | +// to ensure thread safety and data integrity. |
| 46 | +// These tests use a scoped critical section to ensure that the Python object is accessed in a |
| 47 | +// thread-safe manner. |
| 48 | + |
| 49 | +void test_scoped_critical_section(const py::handle &cls) { |
| 50 | + auto barrier = std::barrier(2); |
| 51 | + auto bool_wrapper = cls(false); |
| 52 | + bool output = false; |
| 53 | + |
| 54 | + { |
| 55 | + // Release the GIL to allow run threads in parallel. |
| 56 | + py::gil_scoped_release gil_release{}; |
| 57 | + |
| 58 | + std::thread t1([&]() { |
| 59 | + // Use gil_scoped_acquire to ensure we have a valid Python thread state |
| 60 | + // before entering the critical section. Otherwise, the critical section |
| 61 | + // will cause a segmentation fault. |
| 62 | + py::gil_scoped_acquire ensure_tstate{}; |
| 63 | + // Enter the critical section with the same object as the second thread. |
| 64 | + py::scoped_critical_section lock{bool_wrapper}; |
| 65 | + // At this point, the object is locked by this thread via the scoped_critical_section. |
| 66 | + // This barrier will ensure that the second thread waits until this thread has released |
| 67 | + // the critical section before proceeding. |
| 68 | + barrier.arrive_and_wait(); |
| 69 | + // Sleep for a short time to simulate some work in the critical section. |
| 70 | + // This sleep is necessary to test the locking mechanism properly. |
| 71 | + std::this_thread::sleep_for(std::chrono::milliseconds(10)); |
| 72 | + auto *bw = bool_wrapper.cast<BoolWrapper *>(); |
| 73 | + bw->set(true); |
| 74 | + }); |
| 75 | + |
| 76 | + std::thread t2([&]() { |
| 77 | + // This thread will wait until the first thread has entered the critical section due to |
| 78 | + // the barrier. |
| 79 | + barrier.arrive_and_wait(); |
| 80 | + { |
| 81 | + // Use gil_scoped_acquire to ensure we have a valid Python thread state |
| 82 | + // before entering the critical section. Otherwise, the critical section |
| 83 | + // will cause a segmentation fault. |
| 84 | + py::gil_scoped_acquire ensure_tstate{}; |
| 85 | + // Enter the critical section with the same object as the first thread. |
| 86 | + py::scoped_critical_section lock{bool_wrapper}; |
| 87 | + // At this point, the critical section is released by the first thread, the value |
| 88 | + // is set to true. |
| 89 | + auto *bw = bool_wrapper.cast<BoolWrapper *>(); |
| 90 | + output = bw->get(); |
| 91 | + } |
| 92 | + }); |
| 93 | + |
| 94 | + t1.join(); |
| 95 | + t2.join(); |
| 96 | + } |
| 97 | + |
| 98 | + if (!output) { |
| 99 | + throw std::runtime_error("Scoped critical section test failed: output is false"); |
| 100 | + } |
| 101 | +} |
| 102 | + |
| 103 | +void test_scoped_critical_section2(const py::handle &cls) { |
| 104 | + auto barrier = std::barrier(3); |
| 105 | + auto bool_wrapper1 = cls(false); |
| 106 | + auto bool_wrapper2 = cls(false); |
| 107 | + std::pair<bool, bool> output{false, false}; |
| 108 | + |
| 109 | + { |
| 110 | + // Release the GIL to allow run threads in parallel. |
| 111 | + py::gil_scoped_release gil_release{}; |
| 112 | + |
| 113 | + std::thread t1([&]() { |
| 114 | + // Use gil_scoped_acquire to ensure we have a valid Python thread state |
| 115 | + // before entering the critical section. Otherwise, the critical section |
| 116 | + // will cause a segmentation fault. |
| 117 | + py::gil_scoped_acquire ensure_tstate{}; |
| 118 | + // Enter the critical section with two different objects. |
| 119 | + // This will ensure that the critical section is locked for both objects. |
| 120 | + py::scoped_critical_section lock{bool_wrapper1, bool_wrapper2}; |
| 121 | + // At this point, objects are locked by this thread via the scoped_critical_section. |
| 122 | + // This barrier will ensure that other threads wait until this thread has released |
| 123 | + // the critical section before proceeding. |
| 124 | + barrier.arrive_and_wait(); |
| 125 | + // Sleep for a short time to simulate some work in the critical section. |
| 126 | + // This sleep is necessary to test the locking mechanism properly. |
| 127 | + std::this_thread::sleep_for(std::chrono::milliseconds(10)); |
| 128 | + auto *bw1 = bool_wrapper1.cast<BoolWrapper *>(); |
| 129 | + auto *bw2 = bool_wrapper2.cast<BoolWrapper *>(); |
| 130 | + bw1->set(true); |
| 131 | + bw2->set(true); |
| 132 | + }); |
| 133 | + |
| 134 | + std::thread t2([&]() { |
| 135 | + // This thread will wait until the first thread has entered the critical section due to |
| 136 | + // the barrier. |
| 137 | + barrier.arrive_and_wait(); |
| 138 | + { |
| 139 | + // Use gil_scoped_acquire to ensure we have a valid Python thread state |
| 140 | + // before entering the critical section. Otherwise, the critical section |
| 141 | + // will cause a segmentation fault. |
| 142 | + py::gil_scoped_acquire ensure_tstate{}; |
| 143 | + // Enter the critical section with the same object as the first thread. |
| 144 | + py::scoped_critical_section lock{bool_wrapper1}; |
| 145 | + // At this point, the critical section is released by the first thread, the value |
| 146 | + // is set to true. |
| 147 | + auto *bw1 = bool_wrapper1.cast<BoolWrapper *>(); |
| 148 | + output.first = bw1->get(); |
| 149 | + } |
| 150 | + }); |
| 151 | + |
| 152 | + std::thread t3([&]() { |
| 153 | + // This thread will wait until the first thread has entered the critical section due to |
| 154 | + // the barrier. |
| 155 | + barrier.arrive_and_wait(); |
| 156 | + { |
| 157 | + // Use gil_scoped_acquire to ensure we have a valid Python thread state |
| 158 | + // before entering the critical section. Otherwise, the critical section |
| 159 | + // will cause a segmentation fault. |
| 160 | + py::gil_scoped_acquire ensure_tstate{}; |
| 161 | + // Enter the critical section with the same object as the first thread. |
| 162 | + py::scoped_critical_section lock{bool_wrapper2}; |
| 163 | + // At this point, the critical section is released by the first thread, the value |
| 164 | + // is set to true. |
| 165 | + auto *bw2 = bool_wrapper2.cast<BoolWrapper *>(); |
| 166 | + output.second = bw2->get(); |
| 167 | + } |
| 168 | + }); |
| 169 | + |
| 170 | + t1.join(); |
| 171 | + t2.join(); |
| 172 | + t3.join(); |
| 173 | + } |
| 174 | + |
| 175 | + if (!output.first || !output.second) { |
| 176 | + throw std::runtime_error( |
| 177 | + "Scoped critical section test with two objects failed: output is false"); |
| 178 | + } |
| 179 | +} |
| 180 | + |
| 181 | +void test_scoped_critical_section2_same_object_no_deadlock(const py::handle &cls) { |
| 182 | + auto barrier = std::barrier(2); |
| 183 | + auto bool_wrapper = cls(false); |
| 184 | + bool output = false; |
| 185 | + |
| 186 | + { |
| 187 | + // Release the GIL to allow run threads in parallel. |
| 188 | + py::gil_scoped_release gil_release{}; |
| 189 | + |
| 190 | + std::thread t1([&]() { |
| 191 | + // Use gil_scoped_acquire to ensure we have a valid Python thread state |
| 192 | + // before entering the critical section. Otherwise, the critical section |
| 193 | + // will cause a segmentation fault. |
| 194 | + py::gil_scoped_acquire ensure_tstate{}; |
| 195 | + // Enter the critical section with the same object as the second thread. |
| 196 | + py::scoped_critical_section lock{bool_wrapper, bool_wrapper}; // same object used here |
| 197 | + // At this point, the object is locked by this thread via the scoped_critical_section. |
| 198 | + // This barrier will ensure that the second thread waits until this thread has released |
| 199 | + // the critical section before proceeding. |
| 200 | + barrier.arrive_and_wait(); |
| 201 | + // Sleep for a short time to simulate some work in the critical section. |
| 202 | + // This sleep is necessary to test the locking mechanism properly. |
| 203 | + std::this_thread::sleep_for(std::chrono::milliseconds(10)); |
| 204 | + auto *bw = bool_wrapper.cast<BoolWrapper *>(); |
| 205 | + bw->set(true); |
| 206 | + }); |
| 207 | + |
| 208 | + std::thread t2([&]() { |
| 209 | + // This thread will wait until the first thread has entered the critical section due to |
| 210 | + // the barrier. |
| 211 | + barrier.arrive_and_wait(); |
| 212 | + { |
| 213 | + // Use gil_scoped_acquire to ensure we have a valid Python thread state |
| 214 | + // before entering the critical section. Otherwise, the critical section |
| 215 | + // will cause a segmentation fault. |
| 216 | + py::gil_scoped_acquire ensure_tstate{}; |
| 217 | + // Enter the critical section with the same object as the first thread. |
| 218 | + py::scoped_critical_section lock{bool_wrapper}; |
| 219 | + // At this point, the critical section is released by the first thread, the value |
| 220 | + // is set to true. |
| 221 | + auto *bw = bool_wrapper.cast<BoolWrapper *>(); |
| 222 | + output = bw->get(); |
| 223 | + } |
| 224 | + }); |
| 225 | + |
| 226 | + t1.join(); |
| 227 | + t2.join(); |
| 228 | + } |
| 229 | + |
| 230 | + if (!output) { |
| 231 | + throw std::runtime_error( |
| 232 | + "Scoped critical section test with same object failed: output is false"); |
| 233 | + } |
| 234 | +} |
| 235 | + |
| 236 | +#else |
| 237 | + |
| 238 | +void test_scoped_critical_section(const py::handle &) {} |
| 239 | +void test_scoped_critical_section2(const py::handle &) {} |
| 240 | +void test_scoped_critical_section2_same_object_no_deadlock(const py::handle &) {} |
| 241 | + |
| 242 | +#endif |
| 243 | + |
| 244 | +} // namespace test_scoped_critical_section_ns |
| 245 | + |
| 246 | +TEST_SUBMODULE(scoped_critical_section, m) { |
| 247 | + using namespace test_scoped_critical_section_ns; |
| 248 | + |
| 249 | + m.def("test_one_nullptr", test_one_nullptr); |
| 250 | + m.def("test_two_nullptrs", test_two_nullptrs); |
| 251 | + m.def("test_first_nullptr", test_first_nullptr); |
| 252 | + m.def("test_second_nullptr", test_second_nullptr); |
| 253 | + |
| 254 | + auto BoolWrapperClass = py::class_<BoolWrapper>(m, "BoolWrapper") |
| 255 | + .def(py::init<bool>()) |
| 256 | + .def("get", &BoolWrapper::get) |
| 257 | + .def("set", &BoolWrapper::set); |
| 258 | + auto BoolWrapperHandle = py::handle(BoolWrapperClass); |
| 259 | + (void) BoolWrapperHandle.ptr(); // suppress unused variable warning |
| 260 | + |
| 261 | + m.attr("has_barrier") = |
| 262 | +#ifdef PYBIND11_HAS_BARRIER |
| 263 | + true; |
| 264 | +#else |
| 265 | + false; |
| 266 | +#endif |
| 267 | + |
| 268 | + m.def("test_scoped_critical_section", |
| 269 | + [BoolWrapperHandle]() -> void { test_scoped_critical_section(BoolWrapperHandle); }); |
| 270 | + m.def("test_scoped_critical_section2", |
| 271 | + [BoolWrapperHandle]() -> void { test_scoped_critical_section2(BoolWrapperHandle); }); |
| 272 | + m.def("test_scoped_critical_section2_same_object_no_deadlock", [BoolWrapperHandle]() -> void { |
| 273 | + test_scoped_critical_section2_same_object_no_deadlock(BoolWrapperHandle); |
| 274 | + }); |
| 275 | +} |
0 commit comments