@@ -23,6 +23,17 @@ struct ndarray_handle {
23
23
bool ro;
24
24
};
25
25
26
+ static void ndarray_capsule_destructor (PyObject *o) {
27
+ error_scope scope; // temporarily save any existing errors
28
+ managed_dltensor *mt =
29
+ (managed_dltensor *) PyCapsule_GetPointer (o, " dltensor" );
30
+
31
+ if (mt)
32
+ ndarray_dec_ref ((ndarray_handle *) mt->manager_ctx );
33
+ else
34
+ PyErr_Clear ();
35
+ }
36
+
26
37
static void nb_ndarray_dealloc (PyObject *self) {
27
38
PyTypeObject *tp = Py_TYPE (self);
28
39
ndarray_dec_ref (((nb_ndarray *) self)->th );
@@ -123,12 +134,52 @@ static void nb_ndarray_releasebuffer(PyObject *, Py_buffer *view) {
123
134
PyMem_Free (view->strides );
124
135
}
125
136
137
+
138
+ static PyObject *nb_ndarray_dlpack (PyObject *self, PyTypeObject *,
139
+ PyObject *const *, Py_ssize_t ,
140
+ PyObject *) {
141
+ nb_ndarray *self_nd = (nb_ndarray *) self;
142
+ ndarray_handle *th = self_nd->th ;
143
+
144
+ PyObject *r =
145
+ PyCapsule_New (th->ndarray , " dltensor" , ndarray_capsule_destructor);
146
+ if (r)
147
+ ndarray_inc_ref (th);
148
+ return r;
149
+ }
150
+
151
+ static PyObject *nb_ndarray_dlpack_device (PyObject *self, PyTypeObject *,
152
+ PyObject *const *, Py_ssize_t ,
153
+ PyObject *) {
154
+ nb_ndarray *self_nd = (nb_ndarray *) self;
155
+ dlpack::dltensor &t = self_nd->th ->ndarray ->dltensor ;
156
+ PyObject *r = PyTuple_New (2 );
157
+ PyObject *r0 = PyLong_FromLong (t.device .device_type );
158
+ PyObject *r1 = PyLong_FromLong (t.device .device_id );
159
+ if (!r || !r0 || !r1) {
160
+ Py_XDECREF (r);
161
+ Py_XDECREF (r0);
162
+ Py_XDECREF (r1);
163
+ return nullptr ;
164
+ }
165
+ NB_TUPLE_SET_ITEM (r, 0 , r0);
166
+ NB_TUPLE_SET_ITEM (r, 1 , r1);
167
+ return r;
168
+ }
169
+
170
+ static PyMethodDef nb_ndarray_members[] = {
171
+ { " __dlpack__" , (PyCFunction) nb_ndarray_dlpack, METH_FASTCALL | METH_KEYWORDS, nullptr },
172
+ { " __dlpack_device__" , (PyCFunction) nb_ndarray_dlpack_device, METH_FASTCALL | METH_KEYWORDS, nullptr },
173
+ { nullptr , nullptr , 0 , nullptr }
174
+ };
175
+
126
176
static PyTypeObject *nd_ndarray_tp () noexcept {
127
177
PyTypeObject *tp = internals->nb_ndarray ;
128
178
129
179
if (NB_UNLIKELY (!tp)) {
130
180
PyType_Slot slots[] = {
131
181
{ Py_tp_dealloc, (void *) nb_ndarray_dealloc },
182
+ { Py_tp_methods, (void *) nb_ndarray_members },
132
183
#if PY_VERSION_HEX >= 0x03090000
133
184
{ Py_bf_getbuffer, (void *) nd_ndarray_tpbuffer },
134
185
{ Py_bf_releasebuffer, (void *) nb_ndarray_releasebuffer },
@@ -649,17 +700,6 @@ ndarray_handle *ndarray_create(void *value, size_t ndim, const size_t *shape_in,
649
700
return result.release ();
650
701
}
651
702
652
- static void ndarray_capsule_destructor (PyObject *o) {
653
- error_scope scope; // temporarily save any existing errors
654
- managed_dltensor *mt =
655
- (managed_dltensor *) PyCapsule_GetPointer (o, " dltensor" );
656
-
657
- if (mt)
658
- ndarray_dec_ref ((ndarray_handle *) mt->manager_ctx );
659
- else
660
- PyErr_Clear ();
661
- }
662
-
663
703
PyObject *ndarray_export (ndarray_handle *th, int framework,
664
704
rv_policy policy, cleanup_list *cleanup) noexcept {
665
705
if (!th)
@@ -706,79 +746,47 @@ PyObject *ndarray_export(ndarray_handle *th, int framework,
706
746
}
707
747
}
708
748
709
- if (framework == numpy::value) {
710
- try {
711
- nb_ndarray *h = PyObject_New (nb_ndarray, nd_ndarray_tp ());
712
- if (!h)
713
- return nullptr ;
714
- h->th = th;
715
- ndarray_inc_ref (th);
716
-
717
- object o = steal ((PyObject *) h);
718
- return module_::import_ (" numpy" )
719
- .attr (" array" )(o, arg (" copy" ) = copy)
720
- .release ()
721
- .ptr ();
722
- } catch (const std::exception &e) {
723
- PyErr_Format (PyExc_RuntimeError,
724
- " nanobind::detail::ndarray_export(): could not "
725
- " convert ndarray to NumPy array: %s" , e.what ());
726
- return nullptr ;
727
- }
728
- }
729
-
730
- object package;
731
- try {
732
- switch (framework) {
733
- case no_framework::value:
734
- break ;
735
-
736
- case pytorch::value:
737
- package = module_::import_ (" torch.utils.dlpack" );
738
- break ;
739
-
740
- case tensorflow::value:
741
- package = module_::import_ (" tensorflow.experimental.dlpack" );
742
- break ;
743
-
744
- case jax::value:
745
- package = module_::import_ (" jax.dlpack" );
746
- break ;
747
-
748
- case cupy::value:
749
- package = module_::import_ (" cupy" );
750
- break ;
751
-
752
- default :
753
- check (false , " nanobind::detail::ndarray_export(): unknown "
754
- " framework specified!" );
755
- }
756
- } catch (const std::exception &e) {
757
- PyErr_Format (PyExc_RuntimeError,
758
- " nanobind::detail::ndarray_export(): could not import ndarray "
759
- " framework: %s" , e.what ());
760
- return nullptr ;
761
- }
762
-
763
749
object o;
764
750
if (copy && framework == no_framework::value && th->self ) {
765
751
o = borrow (th->self );
752
+ } else if (framework == numpy::value || framework == jax::value) {
753
+ nb_ndarray *h = PyObject_New (nb_ndarray, nd_ndarray_tp ());
754
+ if (!h)
755
+ return nullptr ;
756
+ h->th = th;
757
+ ndarray_inc_ref (th);
758
+ o = steal ((PyObject *) h);
766
759
} else {
767
760
o = steal (PyCapsule_New (th->ndarray , " dltensor" ,
768
761
ndarray_capsule_destructor));
769
762
ndarray_inc_ref (th);
770
763
}
771
764
765
+ try {
766
+ if (framework == numpy::value) {
767
+ return module_::import_ (" numpy" )
768
+ .attr (" array" )(o, arg (" copy" ) = copy)
769
+ .release ()
770
+ .ptr ();
771
+ } else {
772
+ const char *pkg_name;
773
+ switch (framework) {
774
+ case pytorch::value: pkg_name = " torch.utils.dlpack" ; break ;
775
+ case tensorflow::value: pkg_name = " tensorflow.experimental.dlpack" ; break ;
776
+ case jax::value: pkg_name = " jax.dlpack" ; break ;
777
+ case cupy::value: pkg_name = " cupy" ; break ;
778
+ default : pkg_name = nullptr ;
779
+ }
772
780
773
- if (package.is_valid ()) {
774
- try {
775
- o = package.attr (" from_dlpack" )(o);
776
- } catch (const std::exception &e) {
777
- PyErr_Format (PyExc_RuntimeError,
778
- " nanobind::detail::ndarray_export(): could not "
779
- " import ndarray: %s" , e.what ());
780
- return nullptr ;
781
+ if (pkg_name)
782
+ o = module_::import_ (pkg_name).attr (" from_dlpack" )(o);
781
783
}
784
+ } catch (const std::exception &e) {
785
+ PyErr_Format (PyExc_RuntimeError,
786
+ " nanobind::detail::ndarray_export(): could not "
787
+ " import ndarray: %s" ,
788
+ e.what ());
789
+ return nullptr ;
782
790
}
783
791
784
792
if (copy) {
0 commit comments