Skip to content

Commit 96cca6c

Browse files
committed
follow strict tensorflow alignment requirements
1 parent 4647efc commit 96cca6c

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

tests/test_ndarray.cpp

+7-4
Original file line numberDiff line numberDiff line change
@@ -312,15 +312,18 @@ NB_MODULE(test_ndarray_ext, m) {
312312
});
313313

314314
m.def("ret_tensorflow", []() {
315-
float *f = new float[8] { 1, 2, 3, 4, 5, 6, 7, 8 };
315+
struct alignas(256) Buf {
316+
float f[8];
317+
};
318+
Buf *buf = new Buf({ 1, 2, 3, 4, 5, 6, 7, 8 });
316319
size_t shape[2] = { 2, 4 };
317320

318-
nb::capsule deleter(f, [](void *data) noexcept {
321+
nb::capsule deleter(buf, [](void *data) noexcept {
319322
destruct_count++;
320-
delete[] (float *) data;
323+
delete[] (Buf *) data;
321324
});
322325

323-
return nb::ndarray<nb::tensorflow, float, nb::shape<2, 4>>(f, 2, shape,
326+
return nb::ndarray<nb::tensorflow, float, nb::shape<2, 4>>(buf->f, 2, shape,
324327
deleter);
325328
});
326329

0 commit comments

Comments
 (0)