Skip to content

Commit c1be430

Browse files
authored
Add ndarray tests return_jax and return_tensorflow. (#728)
1 parent 3925f57 commit c1be430

File tree

3 files changed

+105
-44
lines changed

3 files changed

+105
-44
lines changed

tests/test_ndarray.cpp

+26
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,32 @@ NB_MODULE(test_ndarray_ext, m) {
298298
deleter);
299299
});
300300

301+
m.def("ret_jax", []() {
302+
float *f = new float[8] { 1, 2, 3, 4, 5, 6, 7, 8 };
303+
size_t shape[2] = { 2, 4 };
304+
305+
nb::capsule deleter(f, [](void *data) noexcept {
306+
destruct_count++;
307+
delete[] (float *) data;
308+
});
309+
310+
return nb::ndarray<nb::jax, float, nb::shape<2, 4>>(f, 2, shape,
311+
deleter);
312+
});
313+
314+
m.def("ret_tensorflow", []() {
315+
float *f = new float[8] { 1, 2, 3, 4, 5, 6, 7, 8 };
316+
size_t shape[2] = { 2, 4 };
317+
318+
nb::capsule deleter(f, [](void *data) noexcept {
319+
destruct_count++;
320+
delete[] (float *) data;
321+
});
322+
323+
return nb::ndarray<nb::tensorflow, float, nb::shape<2, 4>>(f, 2, shape,
324+
deleter);
325+
});
326+
301327
m.def("ret_array_scalar", []() {
302328
float* f = new float[1] { 1 };
303329
size_t shape[1] = {};

0 commit comments

Comments
 (0)