@@ -225,51 +225,74 @@ namespace pyAMReX
225
225
*/
226
226
227
227
228
- // DLPack protocol (CPU, NVIDIA GPU, AMD GPU, Intel GPU, etc.)
228
+ // DLPack v1.1 protocol (CPU, NVIDIA GPU, AMD GPU, Intel GPU, etc.)
229
229
// https://dmlc.github.io/dlpack/latest/
230
230
// https://github.com/dmlc/dlpack/blob/master/include/dlpack/dlpack.h
231
231
// https://docs.cupy.dev/en/stable/user_guide/interoperability.html#dlpack-data-exchange-protocol
232
- .def (" __dlpack__" , [](Array4<T> const &a4, [[maybe_unused]] py::handle stream = py::none ()) {
232
+ .def (" __dlpack__" , [](
233
+ Array4<T> const &a4
234
+ /* TODO:
235
+ [[maybe_unused]] py::handle stream,
236
+ [[maybe_unused]] std::tuple<int, int> max_version,
237
+ [[maybe_unused]] std::tuple<DLDeviceType, int32_t> dl_device,
238
+ [[maybe_unused]] bool copy
239
+ */
240
+ )
241
+ {
233
242
// Allocate shape/strides arrays
234
243
constexpr int ndim = 4 ;
235
244
auto const len = length (a4);
236
- auto *shape = new int64_t [ndim]{a4. nComp (), len. z , len. y , len. x };
237
- auto *strides = new int64_t [ndim]{a4. nstride , a4. kstride , a4. jstride , 1 };
238
-
239
- // Construct DLTensor
240
- auto *dl_tensor = new DLManagedTensor;
241
- dl_tensor ->dl_tensor .data = const_cast <void *>(static_cast <const void *>(a4.dataPtr ()));
242
- dl_tensor ->dl_tensor .device = dlpack::detect_device_from_pointer (a4.dataPtr ());
243
- dl_tensor ->dl_tensor .ndim = ndim;
244
- dl_tensor ->dl_tensor .dtype = dlpack::get_dlpack_dtype<T>();
245
- dl_tensor ->dl_tensor .shape = shape ;
246
- dl_tensor ->dl_tensor .strides = strides ;
247
- dl_tensor ->dl_tensor .byte_offset = 0 ;
248
- dl_tensor ->manager_ctx = nullptr ;
249
- dl_tensor ->deleter = [](DLManagedTensor *self) {
245
+
246
+ // Construct DLManagedTensorVersioned (DLPack 1.1 standard)
247
+ auto *dl_mgt_tensor = new DLManagedTensorVersioned;
248
+ dl_mgt_tensor-> version = DLPackVersion{};
249
+ dl_mgt_tensor-> flags = 0 ; // No special flags
250
+ dl_mgt_tensor ->dl_tensor .data = const_cast <void *>(static_cast <const void *>(a4.dataPtr ()));
251
+ dl_mgt_tensor ->dl_tensor .device = dlpack::detect_device_from_pointer (a4.dataPtr ());
252
+ dl_mgt_tensor ->dl_tensor .ndim = ndim;
253
+ dl_mgt_tensor ->dl_tensor .dtype = dlpack::get_dlpack_dtype<T>();
254
+ dl_mgt_tensor ->dl_tensor .shape = new int64_t [ndim]{a4. nComp (), len. z , len. y , len. x } ;
255
+ dl_mgt_tensor ->dl_tensor .strides = new int64_t [ndim]{a4. nstride , a4. kstride , a4. jstride , 1 } ;
256
+ dl_mgt_tensor ->dl_tensor .byte_offset = 0 ;
257
+ dl_mgt_tensor ->manager_ctx = nullptr ; // TODO: we can increase/decrease the Python ref counter of the producer here
258
+ dl_mgt_tensor ->deleter = [](DLManagedTensorVersioned *self) {
250
259
delete[] self->dl_tensor .shape ;
251
260
delete[] self->dl_tensor .strides ;
252
261
delete self;
253
262
};
254
263
// Return as Python capsule
255
- return py::capsule (dl_tensor, " dltensor" , [](void * ptr) {
256
- auto * tensor = static_cast <DLManagedTensor*>(ptr);
257
- tensor->deleter (tensor);
258
- });
264
+ return py::capsule (
265
+ dl_mgt_tensor,
266
+ " dltensor_versioned" ,
267
+ /* [](void* ptr) {
268
+ auto* tensor = static_cast<DLManagedTensorVersioned*>(ptr);
269
+ tensor->deleter(tensor);
270
+ }*/
271
+ [](PyObject *capsule)
272
+ {
273
+ auto *p = static_cast <DLManagedTensorVersioned*>(
274
+ PyCapsule_GetPointer (capsule, " dltensor_versioned" ));
275
+ if (p && p->deleter )
276
+ p->deleter (p);
277
+ }
278
+ );
259
279
},
260
- py::arg (" stream" ) = py::none (),
280
+ // py::arg("stream") = py::none(),
281
+ // ... other args & their defaults
261
282
R"doc(
262
283
DLPack protocol for zero-copy tensor exchange.
263
284
See https://dmlc.github.io/dlpack/latest/ for details.
264
285
)doc"
265
286
)
266
287
.def (" __dlpack_device__" , [](Array4<T> const &a4) {
267
288
DLDevice device = dlpack::detect_device_from_pointer (a4.dataPtr ());
268
- return std::make_tuple (device.device_type , device.device_id );
289
+ return std::make_tuple (static_cast < int32_t >( device.device_type ) , device.device_id );
269
290
}, R"doc(
270
291
DLPack device info (device_type, device_id).
271
292
)doc" )
272
293
294
+
295
+
273
296
.def (" to_host" , [](Array4<T> const & a4) {
274
297
// py::tuple to std::vector
275
298
auto const a4i = pyAMReX::array_interface (a4);
0 commit comments