Skip to content

Commit 4647efc

Browse files
committed
improved ndarray conversion for JAX (fixes issue #729)
1 parent c1be430 commit 4647efc

File tree

2 files changed

+83
-75
lines changed

2 files changed

+83
-75
lines changed

src/nb_ndarray.cpp

+81-73
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,17 @@ struct ndarray_handle {
2323
bool ro;
2424
};
2525

26+
static void ndarray_capsule_destructor(PyObject *o) {
27+
error_scope scope; // temporarily save any existing errors
28+
managed_dltensor *mt =
29+
(managed_dltensor *) PyCapsule_GetPointer(o, "dltensor");
30+
31+
if (mt)
32+
ndarray_dec_ref((ndarray_handle *) mt->manager_ctx);
33+
else
34+
PyErr_Clear();
35+
}
36+
2637
static void nb_ndarray_dealloc(PyObject *self) {
2738
PyTypeObject *tp = Py_TYPE(self);
2839
ndarray_dec_ref(((nb_ndarray *) self)->th);
@@ -123,12 +134,52 @@ static void nb_ndarray_releasebuffer(PyObject *, Py_buffer *view) {
123134
PyMem_Free(view->strides);
124135
}
125136

137+
138+
static PyObject *nb_ndarray_dlpack(PyObject *self, PyTypeObject *,
139+
PyObject *const *, Py_ssize_t ,
140+
PyObject *) {
141+
nb_ndarray *self_nd = (nb_ndarray *) self;
142+
ndarray_handle *th = self_nd->th;
143+
144+
PyObject *r =
145+
PyCapsule_New(th->ndarray, "dltensor", ndarray_capsule_destructor);
146+
if (r)
147+
ndarray_inc_ref(th);
148+
return r;
149+
}
150+
151+
static PyObject *nb_ndarray_dlpack_device(PyObject *self, PyTypeObject *,
152+
PyObject *const *, Py_ssize_t ,
153+
PyObject *) {
154+
nb_ndarray *self_nd = (nb_ndarray *) self;
155+
dlpack::dltensor &t = self_nd->th->ndarray->dltensor;
156+
PyObject *r = PyTuple_New(2);
157+
PyObject *r0 = PyLong_FromLong(t.device.device_type);
158+
PyObject *r1 = PyLong_FromLong(t.device.device_id);
159+
if (!r || !r0 || !r1) {
160+
Py_XDECREF(r);
161+
Py_XDECREF(r0);
162+
Py_XDECREF(r1);
163+
return nullptr;
164+
}
165+
NB_TUPLE_SET_ITEM(r, 0, r0);
166+
NB_TUPLE_SET_ITEM(r, 1, r1);
167+
return r;
168+
}
169+
170+
static PyMethodDef nb_ndarray_members[] = {
171+
{ "__dlpack__", (PyCFunction) nb_ndarray_dlpack, METH_FASTCALL | METH_KEYWORDS, nullptr },
172+
{ "__dlpack_device__", (PyCFunction) nb_ndarray_dlpack_device, METH_FASTCALL | METH_KEYWORDS, nullptr },
173+
{ nullptr, nullptr, 0, nullptr }
174+
};
175+
126176
static PyTypeObject *nd_ndarray_tp() noexcept {
127177
PyTypeObject *tp = internals->nb_ndarray;
128178

129179
if (NB_UNLIKELY(!tp)) {
130180
PyType_Slot slots[] = {
131181
{ Py_tp_dealloc, (void *) nb_ndarray_dealloc },
182+
{ Py_tp_methods, (void *) nb_ndarray_members },
132183
#if PY_VERSION_HEX >= 0x03090000
133184
{ Py_bf_getbuffer, (void *) nd_ndarray_tpbuffer },
134185
{ Py_bf_releasebuffer, (void *) nb_ndarray_releasebuffer },
@@ -649,17 +700,6 @@ ndarray_handle *ndarray_create(void *value, size_t ndim, const size_t *shape_in,
649700
return result.release();
650701
}
651702

652-
static void ndarray_capsule_destructor(PyObject *o) {
653-
error_scope scope; // temporarily save any existing errors
654-
managed_dltensor *mt =
655-
(managed_dltensor *) PyCapsule_GetPointer(o, "dltensor");
656-
657-
if (mt)
658-
ndarray_dec_ref((ndarray_handle *) mt->manager_ctx);
659-
else
660-
PyErr_Clear();
661-
}
662-
663703
PyObject *ndarray_export(ndarray_handle *th, int framework,
664704
rv_policy policy, cleanup_list *cleanup) noexcept {
665705
if (!th)
@@ -706,79 +746,47 @@ PyObject *ndarray_export(ndarray_handle *th, int framework,
706746
}
707747
}
708748

709-
if (framework == numpy::value) {
710-
try {
711-
nb_ndarray *h = PyObject_New(nb_ndarray, nd_ndarray_tp());
712-
if (!h)
713-
return nullptr;
714-
h->th = th;
715-
ndarray_inc_ref(th);
716-
717-
object o = steal((PyObject *) h);
718-
return module_::import_("numpy")
719-
.attr("array")(o, arg("copy") = copy)
720-
.release()
721-
.ptr();
722-
} catch (const std::exception &e) {
723-
PyErr_Format(PyExc_RuntimeError,
724-
"nanobind::detail::ndarray_export(): could not "
725-
"convert ndarray to NumPy array: %s", e.what());
726-
return nullptr;
727-
}
728-
}
729-
730-
object package;
731-
try {
732-
switch (framework) {
733-
case no_framework::value:
734-
break;
735-
736-
case pytorch::value:
737-
package = module_::import_("torch.utils.dlpack");
738-
break;
739-
740-
case tensorflow::value:
741-
package = module_::import_("tensorflow.experimental.dlpack");
742-
break;
743-
744-
case jax::value:
745-
package = module_::import_("jax.dlpack");
746-
break;
747-
748-
case cupy::value:
749-
package = module_::import_("cupy");
750-
break;
751-
752-
default:
753-
check(false, "nanobind::detail::ndarray_export(): unknown "
754-
"framework specified!");
755-
}
756-
} catch (const std::exception &e) {
757-
PyErr_Format(PyExc_RuntimeError,
758-
"nanobind::detail::ndarray_export(): could not import ndarray "
759-
"framework: %s", e.what());
760-
return nullptr;
761-
}
762-
763749
object o;
764750
if (copy && framework == no_framework::value && th->self) {
765751
o = borrow(th->self);
752+
} else if (framework == numpy::value || framework == jax::value) {
753+
nb_ndarray *h = PyObject_New(nb_ndarray, nd_ndarray_tp());
754+
if (!h)
755+
return nullptr;
756+
h->th = th;
757+
ndarray_inc_ref(th);
758+
o = steal((PyObject *) h);
766759
} else {
767760
o = steal(PyCapsule_New(th->ndarray, "dltensor",
768761
ndarray_capsule_destructor));
769762
ndarray_inc_ref(th);
770763
}
771764

765+
try {
766+
if (framework == numpy::value) {
767+
return module_::import_("numpy")
768+
.attr("array")(o, arg("copy") = copy)
769+
.release()
770+
.ptr();
771+
} else {
772+
const char *pkg_name;
773+
switch (framework) {
774+
case pytorch::value: pkg_name = "torch.utils.dlpack"; break;
775+
case tensorflow::value: pkg_name = "tensorflow.experimental.dlpack"; break;
776+
case jax::value: pkg_name = "jax.dlpack"; break;
777+
case cupy::value: pkg_name = "cupy"; break;
778+
default: pkg_name = nullptr;
779+
}
772780

773-
if (package.is_valid()) {
774-
try {
775-
o = package.attr("from_dlpack")(o);
776-
} catch (const std::exception &e) {
777-
PyErr_Format(PyExc_RuntimeError,
778-
"nanobind::detail::ndarray_export(): could not "
779-
"import ndarray: %s", e.what());
780-
return nullptr;
781+
if (pkg_name)
782+
o = module_::import_(pkg_name).attr("from_dlpack")(o);
781783
}
784+
} catch (const std::exception &e) {
785+
PyErr_Format(PyExc_RuntimeError,
786+
"nanobind::detail::ndarray_export(): could not "
787+
"import ndarray: %s",
788+
e.what());
789+
return nullptr;
782790
}
783791

784792
if (copy) {

tests/py_stub_test.pyi

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ class AClass:
1515
@staticmethod
1616
def static_method(x): ...
1717

18-
@staticmethod
19-
def class_method(x): ...
18+
@classmethod
19+
def class_method(cls, x): ...
2020

2121
@overload
2222
def overloaded(self, x: int) -> None:

0 commit comments

Comments
 (0)