Skip to content

Commit 5f4815e

Browse files
committed
Make API globals thread safe using atomics
While the GIL is held when the API pointer is updated, this can still race with other threads checking the current value of the API pointer (without holding the GIL) and should therefore using atomics. The loads and stores are performed using acquire-release semantics as we want to dereference the pointer and hence any stores to the referenced memory need to be visible to us. The get function should also be unsafe as the offset it uses cannot be verified which might create an invalid pointer invoking undefined behaviour as per the contract of pointer::offset. Finally, the initialization code is moved into a separate cold function to improve code locality for the fast path.
1 parent 615d5c3 commit 5f4815e

File tree

2 files changed

+41
-25
lines changed

2 files changed

+41
-25
lines changed

src/npyffi/array.rs

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
use libc::FILE;
33
use pyo3::ffi::{self, PyObject, PyTypeObject};
44
use std::os::raw::*;
5-
use std::{cell::Cell, ptr};
5+
use std::ptr::null_mut;
6+
use std::sync::atomic::{AtomicPtr, Ordering};
67

78
use crate::npyffi::*;
89

@@ -12,7 +13,7 @@ const CAPSULE_NAME: &str = "_ARRAY_API";
1213
/// A global variable which stores a ['capsule'](https://docs.python.org/3/c-api/capsule.html)
1314
/// pointer to [Numpy Array API](https://numpy.org/doc/stable/reference/c-api/array.html).
1415
///
15-
/// You can acceess raw c APIs via this variable and its Deref implementation.
16+
/// You can acceess raw C APIs via this variable.
1617
///
1718
/// See [PyArrayAPI](struct.PyArrayAPI.html) for what methods you can use via this variable.
1819
///
@@ -31,28 +32,35 @@ pub static PY_ARRAY_API: PyArrayAPI = PyArrayAPI::new();
3132

3233
/// See [PY_ARRAY_API] for more.
3334
pub struct PyArrayAPI {
34-
api: Cell<*const *const c_void>,
35+
api: AtomicPtr<*const c_void>,
3536
}
3637

3738
impl PyArrayAPI {
3839
const fn new() -> Self {
3940
Self {
40-
api: Cell::new(ptr::null_mut()),
41+
api: AtomicPtr::new(null_mut()),
4142
}
4243
}
43-
fn get(&self, offset: isize) -> *const *const c_void {
44-
if self.api.get().is_null() {
45-
Python::with_gil(|py| {
46-
let api = get_numpy_api(py, MOD_NAME, CAPSULE_NAME);
47-
self.api.set(api);
48-
});
44+
#[cold]
45+
fn init(&self) -> *const *const c_void {
46+
Python::with_gil(|py| {
47+
let mut api = self.api.load(Ordering::Relaxed) as *const *const c_void;
48+
if api.is_null() {
49+
api = get_numpy_api(py, MOD_NAME, CAPSULE_NAME);
50+
self.api.store(api as *mut _, Ordering::Release);
51+
}
52+
api
53+
})
54+
}
55+
unsafe fn get(&self, offset: isize) -> *const *const c_void {
56+
let mut api = self.api.load(Ordering::Acquire) as *const *const c_void;
57+
if api.is_null() {
58+
api = self.init();
4959
}
50-
unsafe { self.api.get().offset(offset) }
60+
api.offset(offset)
5161
}
5262
}
5363

54-
unsafe impl Sync for PyArrayAPI {}
55-
5664
impl PyArrayAPI {
5765
impl_api![0; PyArray_GetNDArrayCVersion() -> c_uint];
5866
impl_api![40; PyArray_SetNumericOps(dict: *mut PyObject) -> c_int];

src/npyffi/ufunc.rs

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
//! Low-Level binding for [UFunc API](https://numpy.org/doc/stable/reference/c-api/ufunc.html)
22
33
use std::os::raw::*;
4-
use std::{cell::Cell, ptr};
4+
use std::ptr::null_mut;
5+
use std::sync::atomic::{AtomicPtr, Ordering};
56

67
use pyo3::ffi::PyObject;
78
use pyo3::Python;
@@ -18,28 +19,35 @@ const CAPSULE_NAME: &str = "_UFUNC_API";
1819
pub static PY_UFUNC_API: PyUFuncAPI = PyUFuncAPI::new();
1920

2021
pub struct PyUFuncAPI {
21-
api: Cell<*const *const c_void>,
22+
api: AtomicPtr<*const c_void>,
2223
}
2324

2425
impl PyUFuncAPI {
2526
const fn new() -> Self {
2627
Self {
27-
api: Cell::new(ptr::null_mut()),
28+
api: AtomicPtr::new(null_mut()),
2829
}
2930
}
30-
fn get(&self, offset: isize) -> *const *const c_void {
31-
if self.api.get().is_null() {
32-
Python::with_gil(|py| {
33-
let api = get_numpy_api(py, MOD_NAME, CAPSULE_NAME);
34-
self.api.set(api);
35-
});
31+
#[cold]
32+
fn init(&self) -> *const *const c_void {
33+
Python::with_gil(|py| {
34+
let mut api = self.api.load(Ordering::Relaxed) as *const *const c_void;
35+
if api.is_null() {
36+
api = get_numpy_api(py, MOD_NAME, CAPSULE_NAME);
37+
self.api.store(api as *mut _, Ordering::Release);
38+
}
39+
api
40+
})
41+
}
42+
unsafe fn get(&self, offset: isize) -> *const *const c_void {
43+
let mut api = self.api.load(Ordering::Acquire) as *const *const c_void;
44+
if api.is_null() {
45+
api = self.init();
3646
}
37-
unsafe { self.api.get().offset(offset) }
47+
api.offset(offset)
3848
}
3949
}
4050

41-
unsafe impl Sync for PyUFuncAPI {}
42-
4351
impl PyUFuncAPI {
4452
impl_api![1; PyUFunc_FromFuncAndData(func: *mut PyUFuncGenericFunction, data: *mut *mut c_void, types: *mut c_char, ntypes: c_int, nin: c_int, nout: c_int, identity: c_int, name: *const c_char, doc: *const c_char, unused: c_int) -> *mut PyObject];
4553
impl_api![2; PyUFunc_RegisterLoopForType(ufunc: *mut PyUFuncObject, usertype: c_int, function: PyUFuncGenericFunction, arg_types: *mut c_int, data: *mut c_void) -> c_int];

0 commit comments

Comments
 (0)