Skip to content

Commit 2628624

Browse files
committed
Add PyArray_Type_Global
To adapt to PyO3 interface
1 parent 2e1b069 commit 2628624

File tree

3 files changed

+17
-14
lines changed

3 files changed

+17
-14
lines changed

Cargo.toml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,5 @@ num-complex = "0.1"
1515
ndarray = "0.10"
1616

1717
[dependencies.pyo3]
18-
git = "http://github.com/termoshtt/pyo3"
19-
branch = "pyobject_macros"
20-
version = "0.2"
18+
git = "https://github.com/PyO3/pyo3.git"
19+
rev = "c22bec6124ab68f47a7f28550931e3060f89071b"

src/array.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
33
use ndarray::*;
44
use npyffi;
5-
use pyo3::ffi;
65
use pyo3::*;
76

87
use std::os::raw::c_void;
@@ -13,8 +12,7 @@ use super::*;
1312

1413
/// Untyped safe interface for NumPy ndarray.
1514
pub struct PyArray(PyObject);
16-
17-
pyobject_native_type!(PyArray, npyffi::PyArray_Type, npyffi::PyArray_Check);
15+
pyobject_native_type!(PyArray, npyffi::PyArray_Type_Global, npyffi::PyArray_Check);
1816

1917
impl PyArray {
2018
pub fn as_array_ptr(&self) -> *mut npyffi::PyArrayObject {

src/npyffi/array.rs

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ pub struct PyArrayModule<'py> {
2525
api: *const *const c_void,
2626
}
2727

28+
#[allow(non_upper_case_globals)]
29+
pub static mut PyArray_Type_Global: PyTypeObject = ffi::PyTypeObject_INIT;
30+
2831
impl<'py> Deref for PyArrayModule<'py> {
2932
type Target = PyModule;
3033
fn deref(&self) -> &Self::Target {
@@ -50,10 +53,11 @@ impl<'py> PyArrayModule<'py> {
5053
let api = unsafe {
5154
ffi::PyCapsule_GetPointer(c_api.as_ptr(), null_mut()) as *const *const c_void
5255
};
53-
Ok(Self {
54-
numpy: numpy,
55-
api: api,
56-
})
56+
let mod_ = PyArrayModule { numpy, api };
57+
unsafe {
58+
PyArray_Type_Global = *mod_.get_type_object(ArrayType::PyArray_Type);
59+
}
60+
Ok(mod_)
5761
}
5862

5963
pyarray_api![0; PyArray_GetNDArrayCVersion() -> c_uint];
@@ -373,11 +377,13 @@ impl_array_type!(
373377
);
374378

375379
#[allow(non_snake_case)]
376-
pub unsafe fn PyArray_Check(np: &PyArrayModule, op: *mut PyObject) -> c_int {
377-
ffi::PyObject_TypeCheck(op, np.get_type_object(ArrayType::PyArray_Type))
380+
pub unsafe fn PyArray_Check(op: *mut PyObject) -> c_int {
381+
let typeobj_ptr: *mut PyTypeObject = &mut PyArray_Type_Global;
382+
ffi::PyObject_TypeCheck(op, typeobj_ptr)
378383
}
379384

380385
#[allow(non_snake_case)]
381-
pub unsafe fn PyArray_CheckExact(np: &PyArrayModule, op: *mut PyObject) -> c_int {
382-
(ffi::Py_TYPE(op) == np.get_type_object(ArrayType::PyArray_Type)) as c_int
386+
pub unsafe fn PyArray_CheckExact(op: *mut PyObject) -> c_int {
387+
let typeobj_ptr: *mut _ = &mut PyArray_Type_Global;
388+
(ffi::Py_TYPE(op) == typeobj_ptr) as c_int
383389
}

0 commit comments

Comments
 (0)