Skip to content

Commit 7d1bc1b

Browse files
Extract JAX and TF from test_ndarray.cpp. Put stub tests behind flag.
This commit extracts the functions involving JAX and TensorFlow types from the module in `test_ndarray.cpp` into their own respective modules and sets up the boilerplate code to turn them into independent tests. This allows to put the corresponding stub generation tests behind a flag and deactivate them by default. Signed-off-by: Ingo Müller <[email protected]>
1 parent d03d5f7 commit 7d1bc1b

12 files changed

+87
-44
lines changed

tests/CMakeLists.txt

+8-1
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ set(TEST_NAMES
7272
enum
7373
eval
7474
ndarray
75+
jax
76+
tensorflow
7577
exception
7678
make_iterator
7779
typing
@@ -90,7 +92,7 @@ endforeach()
9092

9193
target_sources(test_intrusive_ext PRIVATE test_intrusive_impl.cpp)
9294

93-
foreach (NAME functions classes ndarray stl enum typing make_iterator)
95+
foreach (NAME functions classes ndarray jax tensorflow stl enum typing make_iterator)
9496
if (NAME STREQUAL typing)
9597
set(EXTRA
9698
MARKER_FILE py.typed
@@ -138,6 +140,7 @@ target_link_libraries(test_inter_module_2_ext PRIVATE inter_module)
138140

139141
set(TEST_FILES
140142
common.py
143+
conftest.py
141144
test_callbacks.py
142145
test_classes.py
143146
test_eigen.py
@@ -154,6 +157,8 @@ set(TEST_FILES
154157
test_stl_bind_vector.py
155158
test_chrono.py
156159
test_ndarray.py
160+
test_jax.py
161+
test_tensorflow.py
157162
test_stubs.py
158163
test_typing.py
159164
test_thread.py
@@ -163,6 +168,8 @@ set(TEST_FILES
163168
test_functions_ext.pyi.ref
164169
test_make_iterator_ext.pyi.ref
165170
test_ndarray_ext.pyi.ref
171+
test_jax_ext.pyi.ref
172+
test_tensorflow_ext.pyi.ref
166173
test_stl_ext.pyi.ref
167174
test_enum_ext.pyi.ref
168175
test_typing_ext.pyi.ref

tests/conftest.py

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
def pytest_addoption(parser):
2+
parser.addoption('--enable-slow-tests',
3+
action='store_true',
4+
dest="enable-slow-tests",
5+
default=False,
6+
help="enable longrundecorated tests")

tests/test_jax.cpp

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#include <nanobind/nanobind.h>
2+
#include <nanobind/ndarray.h>
3+
4+
namespace nb = nanobind;
5+
6+
int destruct_count = 0;
7+
8+
NB_MODULE(test_jax_ext, m) {
9+
m.def("destruct_count", []() { return destruct_count; });
10+
m.def("ret_jax", []() {
11+
float *f = new float[8] { 1, 2, 3, 4, 5, 6, 7, 8 };
12+
size_t shape[2] = { 2, 4 };
13+
14+
nb::capsule deleter(f, [](void *data) noexcept {
15+
destruct_count++;
16+
delete[] (float *) data;
17+
});
18+
19+
return nb::ndarray<nb::jax, float, nb::shape<2, 4>>(f, 2, shape,
20+
deleter);
21+
});
22+
}

tests/test_jax.py

Whitespace-only changes.

tests/test_jax_ext.pyi.ref

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
import jaxlib.xla_extension
2+
3+
4+
def ret_jax() -> jaxlib.xla_extension.DeviceArray[dtype=float32, shape=(2, 4)]: ...

tests/test_ndarray.cpp

-29
Original file line numberDiff line numberDiff line change
@@ -299,35 +299,6 @@ NB_MODULE(test_ndarray_ext, m) {
299299
deleter);
300300
});
301301

302-
m.def("ret_jax", []() {
303-
float *f = new float[8] { 1, 2, 3, 4, 5, 6, 7, 8 };
304-
size_t shape[2] = { 2, 4 };
305-
306-
nb::capsule deleter(f, [](void *data) noexcept {
307-
destruct_count++;
308-
delete[] (float *) data;
309-
});
310-
311-
return nb::ndarray<nb::jax, float, nb::shape<2, 4>>(f, 2, shape,
312-
deleter);
313-
});
314-
315-
m.def("ret_tensorflow", []() {
316-
struct alignas(256) Buf {
317-
float f[8];
318-
};
319-
Buf *buf = new Buf({ 1, 2, 3, 4, 5, 6, 7, 8 });
320-
size_t shape[2] = { 2, 4 };
321-
322-
nb::capsule deleter(buf, [](void *data) noexcept {
323-
destruct_count++;
324-
delete (Buf *) data;
325-
});
326-
327-
return nb::ndarray<nb::tensorflow, float, nb::shape<2, 4>>(buf->f, 2, shape,
328-
deleter);
329-
});
330-
331302
m.def("ret_array_scalar", []() {
332303
float* f = new float[1] { 1 };
333304
size_t shape[1] = {};

tests/test_ndarray.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import test_ndarray_ext as t
2+
import test_jax_ext as tj
3+
import test_tensorflow_ext as tt
24
import pytest
35
import warnings
46
import importlib
@@ -377,26 +379,26 @@ def test18_return_pytorch():
377379
@needs_jax
378380
def test19_return_jax():
379381
collect()
380-
dc = t.destruct_count()
381-
x = t.ret_jax()
382+
dc = tj.destruct_count()
383+
x = tj.ret_jax()
382384
assert x.shape == (2, 4)
383385
assert jnp.all(x == jnp.array([[1,2,3,4], [5,6,7,8]], dtype=jnp.float32))
384386
del x
385387
collect()
386-
assert t.destruct_count() - dc == 1
388+
assert tj.destruct_count() - dc == 1
387389

388390

389391
@needs_tensorflow
390392
def test20_return_tensorflow():
391393
collect()
392-
dc = t.destruct_count()
393-
x = t.ret_tensorflow()
394+
dc = tt.destruct_count()
395+
x = tt.ret_tensorflow()
394396
assert x.get_shape().as_list() == [2, 4]
395397
assert tf.math.reduce_all(
396398
x == tf.constant([[1,2,3,4], [5,6,7,8]], dtype=tf.float32))
397399
del x
398400
collect()
399-
assert t.destruct_count() - dc == 1
401+
assert tt.destruct_count() - dc == 1
400402

401403

402404
@needs_numpy

tests/test_ndarray_ext.pyi.ref

-6
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
from typing import Annotated, overload
22

3-
import jaxlib.xla_extension
43
from numpy.typing import ArrayLike
5-
import tensorflow.python.framework.ops
64

75

86
class Cls:
@@ -175,8 +173,6 @@ def ret_infer_c() -> Annotated[ArrayLike, dict(dtype='float32', shape=(2, 4), or
175173

176174
def ret_infer_f() -> Annotated[ArrayLike, dict(dtype='float32', shape=(2, 4), order='F')]: ...
177175

178-
def ret_jax() -> jaxlib.xla_extension.DeviceArray[dtype=float32, shape=(2, 4)]: ...
179-
180176
def ret_numpy() -> Annotated[ArrayLike, dict(dtype='float32', shape=(2, 4))]: ...
181177

182178
def ret_numpy_const() -> Annotated[ArrayLike, dict(dtype='float32', shape=(2, 4), writable=False)]: ...
@@ -189,8 +185,6 @@ def ret_numpy_half() -> Annotated[ArrayLike, dict(dtype='float16', shape=(2, 4))
189185

190186
def ret_pytorch() -> Annotated[ArrayLike, dict(dtype='float32', shape=(2, 4))]: ...
191187

192-
def ret_tensorflow() -> tensorflow.python.framework.ops.EagerTensor[dtype=float32, shape=(2, 4)]: ...
193-
194188
def return_dlpack() -> Annotated[ArrayLike, dict(dtype='float32', shape=(2, 4))]: ...
195189

196190
@overload

tests/test_stubs.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,23 @@ def remove_platform_dependent(s):
2424
i += 1
2525
return s2
2626

27+
2728
ref_paths = list(pathlib.Path(__file__).parent.glob('*.pyi.ref'))
29+
ref_path_ids = [
30+
p.name.removeprefix('test_').removesuffix('.pyi.ref') for p in ref_paths
31+
]
2832
assert len(ref_paths) > 0, "Stub reference files not found!"
2933

3034
@skip_on_unsupported
31-
@pytest.mark.parametrize('p_ref', ref_paths)
32-
def test01_check_stub_refs(p_ref):
35+
@pytest.mark.parametrize('p_ref', ref_paths, ids=ref_path_ids)
36+
def test01_check_stub_refs(p_ref, request):
3337
"""
3438
Check that generated stub files match reference input
3539
"""
40+
if not request.config.getoption('enable-slow-tests') and any(
41+
(x in p_ref.name for x in ['jax', 'tensorflow'])):
42+
pytest.skip("skipping because slow tests are not enabled")
43+
3644
p_in = p_ref.with_suffix('')
3745
with open(p_ref, 'r') as f:
3846
s_ref = f.read().split('\n')

tests/test_tensorflow.cpp

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#include <nanobind/nanobind.h>
2+
#include <nanobind/ndarray.h>
3+
namespace nb = nanobind;
4+
5+
int destruct_count = 0;
6+
7+
8+
NB_MODULE(test_tensorflow_ext, m) {
9+
m.def("destruct_count", []() { return destruct_count; });
10+
m.def("ret_tensorflow", []() {
11+
struct alignas(256) Buf {
12+
float f[8];
13+
};
14+
Buf *buf = new Buf({ 1, 2, 3, 4, 5, 6, 7, 8 });
15+
size_t shape[2] = { 2, 4 };
16+
17+
nb::capsule deleter(buf, [](void *data) noexcept {
18+
destruct_count++;
19+
delete (Buf *) data;
20+
});
21+
22+
return nb::ndarray<nb::tensorflow, float, nb::shape<2, 4>>(buf->f, 2, shape,
23+
deleter);
24+
});
25+
}

tests/test_tensorflow.py

Whitespace-only changes.

tests/test_tensorflow_ext.pyi.ref

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
import tensorflow.python.framework.ops
2+
3+
4+
def ret_tensorflow() -> tensorflow.python.framework.ops.EagerTensor[dtype=float32, shape=(2, 4)]: ...

0 commit comments

Comments
 (0)