Skip to content

Commit 3c386ec

Browse files
authored
Merge pull request #36 from eriknw/support_udts
Update Cython code to support UDTs
2 parents b5d5144 + 556538d commit 3c386ec

File tree

2 files changed

+34
-23
lines changed

2 files changed

+34
-23
lines changed

suitesparse_graphblas/utils.pxd

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
11
from libc.stdint cimport uint64_t
2-
from numpy cimport ndarray
2+
from numpy cimport ndarray, npy_intp, dtype as dtype_t
33

44

55
cdef extern from "numpy/arrayobject.h" nogil:
66
# These aren't public (i.e., "extern"), but other projects use them too
7-
void *PyDataMem_NEW(size_t)
8-
void *PyDataMem_NEW_ZEROED(size_t, size_t)
9-
void *PyDataMem_RENEW(void *, size_t)
10-
void PyDataMem_FREE(void *)
7+
void *PyDataMem_NEW(size_t size)
8+
void *PyDataMem_NEW_ZEROED(size_t nmemb, size_t size)
9+
void *PyDataMem_RENEW(void *ptr, size_t size)
10+
void PyDataMem_FREE(void *ptr)
1111
# These are available in newer Cython versions
12-
void PyArray_ENABLEFLAGS(ndarray, int flags)
13-
void PyArray_CLEARFLAGS(ndarray, int flags)
12+
void PyArray_ENABLEFLAGS(ndarray array, int flags)
13+
void PyArray_CLEARFLAGS(ndarray array, int flags)
14+
# Not exposed by Cython (b/c it steals a reference from dtype)
15+
ndarray PyArray_NewFromDescr(
16+
type subtype, dtype_t dtype, int nd, npy_intp *dims, npy_intp *strides, void *data, int flags, object obj
17+
)
1418

1519
ctypedef enum GrB_Mode:
1620
GrB_NONBLOCKING
@@ -24,11 +28,13 @@ ctypedef uint64_t (*GxB_init)(
2428
void (*user_free_function)(void *),
2529
)
2630

27-
cpdef int call_gxb_init(ffi, lib, int mode)
31+
cpdef int call_gxb_init(object ffi, object lib, int mode)
2832

29-
cpdef ndarray claim_buffer(ffi, cdata, size_t size, dtype)
33+
cpdef ndarray claim_buffer(object ffi, object cdata, size_t size, dtype_t dtype)
3034

31-
cpdef ndarray claim_buffer_2d(ffi, cdata, size_t cdata_size, size_t nrows, size_t ncols, dtype, bint is_c_order)
35+
cpdef ndarray claim_buffer_2d(
36+
object ffi, object cdata, size_t cdata_size, size_t nrows, size_t ncols, dtype_t dtype, bint is_c_order
37+
)
3238

3339
cpdef unclaim_buffer(ndarray array)
3440

suitesparse_graphblas/utils.pyx

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
import numpy as np
2+
from cpython.ref cimport Py_INCREF
23
from libc.stdint cimport uintptr_t
34
from numpy cimport (
45
NPY_ARRAY_F_CONTIGUOUS,
56
NPY_ARRAY_OWNDATA,
67
NPY_ARRAY_WRITEABLE,
7-
PyArray_New,
8-
PyArray_SimpleNewFromData,
98
import_array,
109
ndarray,
1110
npy_intp,
11+
dtype as dtype_t,
1212
)
1313

1414
import_array()
1515

16-
cpdef int call_gxb_init(ffi, lib, int mode):
16+
cpdef int call_gxb_init(object ffi, object lib, int mode):
1717
# We need to call `GxB_init`, but we didn't compile Cython against GraphBLAS. So, we get it from cffi.
1818
# Step 1: ffi.addressof(lib, "GxB_init")
1919
# Return type: cffi.cdata object of a function pointer. Can't cast to int.
@@ -30,32 +30,37 @@ cpdef int call_gxb_init(ffi, lib, int mode):
3030
return func(<GrB_Mode>mode, PyDataMem_NEW, PyDataMem_NEW_ZEROED, PyDataMem_RENEW, PyDataMem_FREE)
3131

3232

33-
cpdef ndarray claim_buffer(ffi, cdata, size_t size, dtype):
33+
cpdef ndarray claim_buffer(object ffi, object cdata, size_t size, dtype_t dtype):
3434
cdef:
3535
npy_intp dims = size
3636
uintptr_t ptr = int(ffi.cast("uintptr_t", cdata))
37-
ndarray array = PyArray_SimpleNewFromData(1, &dims, dtype.num, <void*>ptr)
37+
ndarray array
38+
Py_INCREF(dtype)
39+
array = PyArray_NewFromDescr(
40+
ndarray, dtype, 1, &dims, NULL, <void*>ptr, NPY_ARRAY_WRITEABLE, <object>NULL
41+
)
3842
PyArray_ENABLEFLAGS(array, NPY_ARRAY_OWNDATA)
3943
return array
4044

4145

42-
cpdef ndarray claim_buffer_2d(ffi, cdata, size_t cdata_size, size_t nrows, size_t ncols, dtype, bint is_c_order):
46+
cpdef ndarray claim_buffer_2d(
47+
object ffi, object cdata, size_t cdata_size, size_t nrows, size_t ncols, dtype_t dtype, bint is_c_order
48+
):
4349
cdef:
4450
size_t size = nrows * ncols
4551
ndarray array
4652
uintptr_t ptr
4753
npy_intp dims[2]
54+
int flags = NPY_ARRAY_WRITEABLE
4855
if cdata_size == size:
4956
ptr = int(ffi.cast("uintptr_t", cdata))
5057
dims[0] = nrows
5158
dims[1] = ncols
52-
if is_c_order:
53-
array = PyArray_SimpleNewFromData(2, dims, dtype.num, <void*>ptr)
54-
else:
55-
array = PyArray_New(
56-
ndarray, 2, dims, dtype.num, NULL, <void*>ptr, -1,
57-
NPY_ARRAY_F_CONTIGUOUS | NPY_ARRAY_WRITEABLE, <object>NULL
58-
)
59+
if not is_c_order:
60+
flags |= NPY_ARRAY_F_CONTIGUOUS
61+
array = PyArray_NewFromDescr(
62+
ndarray, dtype, 2, dims, NULL, <void*>ptr, flags, <object>NULL
63+
)
5964
PyArray_ENABLEFLAGS(array, NPY_ARRAY_OWNDATA)
6065
elif cdata_size > size: # pragma: no cover
6166
array = claim_buffer(ffi, cdata, cdata_size, dtype)

0 commit comments

Comments
 (0)