Commit 833eaff

More Cleanup

DLPack 1.1, e.g., in NumPy 2.1+. Tests do not yet pass.

1 parent 4539006 commit 833eaff

File tree

3 files changed (+63, -67 lines)


src/Base/Array4.H

Lines changed: 45 additions & 22 deletions
```diff
@@ -225,51 +225,74 @@ namespace pyAMReX
         */


-        // DLPack protocol (CPU, NVIDIA GPU, AMD GPU, Intel GPU, etc.)
+        // DLPack v1.1 protocol (CPU, NVIDIA GPU, AMD GPU, Intel GPU, etc.)
         // https://dmlc.github.io/dlpack/latest/
         // https://github.com/dmlc/dlpack/blob/master/include/dlpack/dlpack.h
         // https://docs.cupy.dev/en/stable/user_guide/interoperability.html#dlpack-data-exchange-protocol
-        .def("__dlpack__", [](Array4<T> const &a4, [[maybe_unused]] py::handle stream = py::none()) {
+        .def("__dlpack__", [](
+            Array4<T> const &a4
+            /* TODO:
+            [[maybe_unused]] py::handle stream,
+            [[maybe_unused]] std::tuple<int, int> max_version,
+            [[maybe_unused]] std::tuple<DLDeviceType, int32_t> dl_device,
+            [[maybe_unused]] bool copy
+            */
+        )
+        {
             // Allocate shape/strides arrays
             constexpr int ndim = 4;
             auto const len = length(a4);
-            auto *shape = new int64_t[ndim]{a4.nComp(), len.z, len.y, len.x};
-            auto *strides = new int64_t[ndim]{a4.nstride, a4.kstride, a4.jstride, 1};
-
-            // Construct DLTensor
-            auto *dl_tensor = new DLManagedTensor;
-            dl_tensor->dl_tensor.data = const_cast<void*>(static_cast<const void*>(a4.dataPtr()));
-            dl_tensor->dl_tensor.device = dlpack::detect_device_from_pointer(a4.dataPtr());
-            dl_tensor->dl_tensor.ndim = ndim;
-            dl_tensor->dl_tensor.dtype = dlpack::get_dlpack_dtype<T>();
-            dl_tensor->dl_tensor.shape = shape;
-            dl_tensor->dl_tensor.strides = strides;
-            dl_tensor->dl_tensor.byte_offset = 0;
-            dl_tensor->manager_ctx = nullptr;
-            dl_tensor->deleter = [](DLManagedTensor *self) {
+
+            // Construct DLManagedTensorVersioned (DLPack 1.1 standard)
+            auto *dl_mgt_tensor = new DLManagedTensorVersioned;
+            dl_mgt_tensor->version = DLPackVersion{};
+            dl_mgt_tensor->flags = 0;  // No special flags
+            dl_mgt_tensor->dl_tensor.data = const_cast<void*>(static_cast<const void*>(a4.dataPtr()));
+            dl_mgt_tensor->dl_tensor.device = dlpack::detect_device_from_pointer(a4.dataPtr());
+            dl_mgt_tensor->dl_tensor.ndim = ndim;
+            dl_mgt_tensor->dl_tensor.dtype = dlpack::get_dlpack_dtype<T>();
+            dl_mgt_tensor->dl_tensor.shape = new int64_t[ndim]{a4.nComp(), len.z, len.y, len.x};
+            dl_mgt_tensor->dl_tensor.strides = new int64_t[ndim]{a4.nstride, a4.kstride, a4.jstride, 1};
+            dl_mgt_tensor->dl_tensor.byte_offset = 0;
+            dl_mgt_tensor->manager_ctx = nullptr;  // TODO: we can increase/decrease the Python ref counter of the producer here
+            dl_mgt_tensor->deleter = [](DLManagedTensorVersioned *self) {
                 delete[] self->dl_tensor.shape;
                 delete[] self->dl_tensor.strides;
                 delete self;
             };
             // Return as Python capsule
-            return py::capsule(dl_tensor, "dltensor", [](void* ptr) {
-                auto* tensor = static_cast<DLManagedTensor*>(ptr);
-                tensor->deleter(tensor);
-            });
+            return py::capsule(
+                dl_mgt_tensor,
+                "dltensor_versioned",
+                /*[](void* ptr) {
+                    auto* tensor = static_cast<DLManagedTensorVersioned*>(ptr);
+                    tensor->deleter(tensor);
+                }*/
+                [](PyObject *capsule)
+                {
+                    auto *p = static_cast<DLManagedTensorVersioned*>(
+                        PyCapsule_GetPointer(capsule, "dltensor_versioned"));
+                    if (p && p->deleter)
+                        p->deleter(p);
+                }
+            );
         },
-        py::arg("stream") = py::none(),
+        //py::arg("stream") = py::none(),
+        // ... other args & their defaults
         R"doc(
             DLPack protocol for zero-copy tensor exchange.
             See https://dmlc.github.io/dlpack/latest/ for details.
         )doc"
         )
         .def("__dlpack_device__", [](Array4<T> const &a4) {
             DLDevice device = dlpack::detect_device_from_pointer(a4.dataPtr());
-            return std::make_tuple(device.device_type, device.device_id);
+            return std::make_tuple(static_cast<int32_t>(device.device_type), device.device_id);
         }, R"doc(
             DLPack device info (device_type, device_id).
         )doc")

+
+
         .def("to_host", [](Array4<T> const & a4) {
             // py::tuple to std::vector
             auto const a4i = pyAMReX::array_interface(a4);
```
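Two details of the capsule contract above are worth noting. Per the DLPack Python specification, a consumer that assumes ownership renames the capsule from "dltensor_versioned" to "used_dltensor_versioned", and the capsule destructor must then leave the tensor alone; the lambda in the diff relies on PyCapsule_GetPointer returning NULL for a renamed capsule, which works but also sets a spurious Python error. Separately, the manager_ctx TODO can be realized by holding a strong reference to the producing object. A minimal sketch of both ideas; the helper attach_owner and the py::object owner are hypothetical and not part of this commit:

```cpp
// Sketch only: a rename-aware capsule destructor plus an owning
// manager_ctx. "owner" is a hypothetical strong reference to the
// producing Python object; this commit leaves manager_ctx as nullptr.
#include <pybind11/pybind11.h>
#include "dlpack.h"  // assumed include path for DLManagedTensorVersioned

namespace py = pybind11;

void capsule_destructor (PyObject *capsule)
{
    // Consumers rename the capsule to "used_dltensor_versioned" once
    // they assume ownership; the deleter must not run in that case.
    if (!PyCapsule_IsValid(capsule, "dltensor_versioned"))
        return;
    auto *p = static_cast<DLManagedTensorVersioned*>(
        PyCapsule_GetPointer(capsule, "dltensor_versioned"));
    if (p && p->deleter)
        p->deleter(p);
}

void attach_owner (DLManagedTensorVersioned *t, py::object owner)
{
    t->manager_ctx = owner.release().ptr();  // take a strong reference
    t->deleter = [](DLManagedTensorVersioned *self) {
        delete[] self->dl_tensor.shape;
        delete[] self->dl_tensor.strides;
        py::gil_scoped_acquire gil;  // the deleter may run off-thread
        Py_XDECREF(static_cast<PyObject*>(self->manager_ctx));
        delete self;
    };
}
```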

src/dlpack.h

Lines changed: 10 additions & 39 deletions
```diff
@@ -40,9 +40,9 @@ extern "C" {
  */
 typedef struct {
   /*! \brief DLPack major version. */
-  uint32_t major;
+  uint32_t major = 1;
   /*! \brief DLPack minor version. */
-  uint32_t minor;
+  uint32_t minor = 1;
 } DLPackVersion;

 /*!
```
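These in-class defaults are load-bearing: the __dlpack__ binding in Array4.H writes `dl_mgt_tensor->version = DLPackVersion{};`, and with the initializers that value-initialization yields version 1.1 rather than 0.0. A minimal check of this assumption (the include path is illustrative; the header compiles as C++ despite the extern "C" block):

```cpp
// Value-initializing DLPackVersion now reports DLPack 1.1, which is
// exactly what the __dlpack__ binding in Array4.H relies on.
#include <cassert>
#include "dlpack.h"  // assumed include path

int main ()
{
    DLPackVersion v{};
    assert(v.major == 1 && v.minor == 1);  // before this commit: {0, 0}
    return 0;
}
```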
```diff
@@ -238,37 +238,7 @@ typedef struct {
   uint64_t byte_offset;
 } DLTensor;

-/*!
- * \brief C Tensor object, manage memory of DLTensor. This data structure is
- * intended to facilitate the borrowing of DLTensor by another framework. It is
- * not meant to transfer the tensor. When the borrowing framework doesn't need
- * the tensor, it should call the deleter to notify the host that the resource
- * is no longer needed.
- *
- * \note This data structure is used as Legacy DLManagedTensor
- * in DLPack exchange and is deprecated after DLPack v0.8
- * Use DLManagedTensorVersioned instead.
- * This data structure may get renamed or deleted in future versions.
- *
- * \sa DLManagedTensorVersioned
- */
-typedef struct DLManagedTensor {
-  /*! \brief DLTensor which is being memory managed */
-  DLTensor dl_tensor;
-  /*! \brief the context of the original host framework of DLManagedTensor in
-   * which DLManagedTensor is used in the framework. It can also be NULL.
-   */
-  void * manager_ctx;
-  /*!
-   * \brief Destructor - this should be called
-   * to destruct the manager_ctx which backs the DLManagedTensor. It can be
-   * NULL if there is no way for the caller to provide a reasonable destructor.
-   * The destructor deletes the argument self as well.
-   */
-  void (*deleter)(struct DLManagedTensor * self);
-} DLManagedTensor;
-
-// bit masks used in in the DLManagedTensorVersioned
+// bit masks used in the DLManagedTensorVersioned

 /*! \brief bit mask to indicate that the tensor is read only. */
 #define DLPACK_FLAG_BITMASK_READ_ONLY (1UL << 0UL)
```
```diff
@@ -344,34 +314,35 @@ namespace pyAMReX::dlpack
     AMREX_INLINE
     DLDataType get_dlpack_dtype ()
     {
+        using V = std::decay_t<T>;
         DLDataType dtype{};

-        if constexpr (std::is_same_v<T, float>) {
+        if constexpr (std::is_same_v<V, float>) {
             dtype.code = kDLFloat;
             dtype.bits = 32;
             dtype.lanes = 1;
         }
-        else if constexpr (std::is_same_v<T, double>) {
+        else if constexpr (std::is_same_v<V, double>) {
             dtype.code = kDLFloat;
             dtype.bits = 64;
             dtype.lanes = 1;
         }
-        else if constexpr (std::is_same_v<T, int32_t>) {
+        else if constexpr (std::is_same_v<V, int32_t>) {
             dtype.code = kDLInt;
             dtype.bits = 32;
             dtype.lanes = 1;
         }
-        else if constexpr (std::is_same_v<T, int64_t>) {
+        else if constexpr (std::is_same_v<V, int64_t>) {
             dtype.code = kDLInt;
             dtype.bits = 64;
             dtype.lanes = 1;
         }
-        else if constexpr (std::is_same_v<T, uint32_t>) {
+        else if constexpr (std::is_same_v<V, uint32_t>) {
             dtype.code = kDLUInt;
             dtype.bits = 32;
             dtype.lanes = 1;
         }
-        else if constexpr (std::is_same_v<T, uint64_t>) {
+        else if constexpr (std::is_same_v<V, uint64_t>) {
             dtype.code = kDLUInt;
             dtype.bits = 64;
             dtype.lanes = 1;
```
tests/test_array4.py

Lines changed: 8 additions & 6 deletions
```diff
@@ -31,7 +31,9 @@ def test_array4():
     )
     print(f"\nx: {x.__array_interface__} {x.dtype}")
     arr = amr.Array4_double(x)
-    print(f"arr: {arr.__array_interface__}")
+    print(f"arr: DLPack device info: {arr.__dlpack_device__()}")
+    # print(f"arr: DLPack: {arr.__dlpack__()}")
+    print(f"x.shape: {x.shape}")
     print(arr)
     assert arr.nComp == 1
@@ -44,16 +46,16 @@ def test_array4():
     assert arr[0, 0, 0] == 1
     assert arr[3, 2, 1] == 1

-    # copy to numpy
-    c_arr2np = np.array(arr, copy=True)  # segfaults on Windows
+    # copy to numpy using DLPack
+    c_arr2np = np.from_dlpack(arr)
     assert c_arr2np.ndim == 4
     assert c_arr2np.dtype == np.dtype("double")
     print(f"c_arr2np: {c_arr2np.__array_interface__}")
     np.testing.assert_array_equal(x, c_arr2np[0, :, :, :])
     assert c_arr2np[0, 1, 1, 1] == 42

-    # view to numpy
-    v_arr2np = np.array(arr, copy=False)
+    # view to numpy using DLPack
+    v_arr2np = np.from_dlpack(arr)
     assert c_arr2np.ndim == 4
     assert v_arr2np.dtype == np.dtype("double")
     np.testing.assert_array_equal(x, v_arr2np[0, :, :, :])
@@ -65,7 +67,7 @@ def test_array4():

     # copy array4 (view)
     c_arr = amr.Array4_double(arr)
-    v_carr2np = np.array(c_arr, copy=False)
+    v_carr2np = np.from_dlpack(c_arr)
     x[1, 1, 1] = 44
     assert v_carr2np[0, 1, 1, 1] == 44
```
