Skip to content

Commit 92e8710

Browse files
committed
Skip GPU tests if not available
1 parent 850f494 commit 92e8710

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

tests/unit/test_sim.py

+28
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,24 @@
1010
from crazyflow.sim.physics import Physics
1111

1212

13+
def available_backends() -> list[str]:
14+
"""Return list of available JAX backends."""
15+
backends = []
16+
for backend in ["tpu", "gpu", "cpu"]:
17+
try:
18+
jax.devices(backend)
19+
except RuntimeError:
20+
pass
21+
else:
22+
backends.append(backend)
23+
return backends
24+
25+
26+
def skip_unavailable_device(device: str):
27+
if device not in available_backends():
28+
pytest.skip(f"{device} device not available")
29+
30+
1331
@pytest.mark.unit
1432
@pytest.mark.parametrize("physics", Physics)
1533
@pytest.mark.parametrize("device", ["gpu", "cpu"])
@@ -20,6 +38,7 @@ def test_sim_creation(
2038
physics: Physics, device: str, control: Control, controller: Controller, n_worlds: int
2139
):
2240
n_drones = 1
41+
skip_unavailable_device(device)
2342

2443
def create_sim() -> Sim:
2544
return Sim(
@@ -58,6 +77,7 @@ def create_sim() -> Sim:
5877
@pytest.mark.unit
5978
@pytest.mark.parametrize("device", ["gpu", "cpu"])
6079
def test_setup(device: str):
80+
skip_unavailable_device(device)
6181
sim = Sim(n_worlds=2, n_drones=3, device=device)
6282
sim.setup()
6383

@@ -69,6 +89,7 @@ def test_setup(device: str):
6989
@pytest.mark.parametrize("n_drones", [1, 3])
7090
def test_reset(device: str, physics: Physics, n_worlds: int, n_drones: int):
7191
"""Test that reset without mask resets all worlds to default state."""
92+
skip_unavailable_device(device)
7293
sim = Sim(n_worlds=n_worlds, n_drones=n_drones, physics=physics, device=device)
7394
if physics == Physics.mujoco:
7495
return # MuJoCo is not yet supported. TODO: Enable once supported
@@ -113,6 +134,7 @@ def test_reset(device: str, physics: Physics, n_worlds: int, n_drones: int):
113134
@pytest.mark.parametrize("physics", Physics)
114135
def test_reset_masked(device: str, physics: Physics):
115136
"""Test that reset with mask only resets specified worlds."""
137+
skip_unavailable_device(device)
116138
sim = Sim(n_worlds=2, n_drones=1, physics=physics, device=device)
117139

118140
# Modify states
@@ -169,6 +191,7 @@ def test_sim_step(
169191
controller: Controller,
170192
device: str,
171193
):
194+
skip_unavailable_device(device)
172195
if n_drones * n_worlds > 1 and controller == Controller.pycffirmware:
173196
return # PyCFFirmware does not support multiple drones
174197
sim = Sim(
@@ -218,6 +241,7 @@ def test_sim_control(control: Control, control_freq: int):
218241
@pytest.mark.unit
219242
@pytest.mark.parametrize("device", ["gpu", "cpu"])
220243
def test_sim_state_control(device: str):
244+
skip_unavailable_device(device)
221245
sim = Sim(n_worlds=2, n_drones=3, control=Control.state, device=device)
222246
cmd = np.random.rand(sim.n_worlds, sim.n_drones, 13)
223247
sim.state_control(cmd)
@@ -228,6 +252,7 @@ def test_sim_state_control(device: str):
228252
@pytest.mark.unit
229253
@pytest.mark.parametrize("device", ["gpu", "cpu"])
230254
def test_sim_attitude_control(device: str):
255+
skip_unavailable_device(device)
231256
sim = Sim(n_worlds=2, n_drones=3, control=Control.attitude, device=device)
232257
cmd = np.random.rand(sim.n_worlds, sim.n_drones, 4)
233258
sim.attitude_control(cmd)
@@ -238,6 +263,7 @@ def test_sim_attitude_control(device: str):
238263
@pytest.mark.parametrize("device", ["gpu", "cpu"])
239264
@pytest.mark.render
240265
def test_render(device: str):
266+
skip_unavailable_device(device)
241267
sim = Sim(device=device)
242268
sim.render()
243269
sim.viewer.close()
@@ -246,6 +272,7 @@ def test_render(device: str):
246272
@pytest.mark.unit
247273
@pytest.mark.parametrize("device", ["gpu", "cpu"])
248274
def test_device(device: str):
275+
skip_unavailable_device(device)
249276
sim = Sim(n_worlds=2, physics=Physics.sys_id, device=device)
250277
sim.step()
251278
assert sim.states.pos.device == jax.devices(device)[0]
@@ -257,6 +284,7 @@ def test_device(device: str):
257284
@pytest.mark.parametrize("n_worlds", [1, 2])
258285
@pytest.mark.parametrize("n_drones", [1, 3])
259286
def test_shape_consistency(device: str, n_drones: int, n_worlds: int):
287+
skip_unavailable_device(device)
260288
sim = Sim(n_worlds=n_worlds, n_drones=n_drones, physics=Physics.sys_id, device=device)
261289
qpos_shape, qvel_shape = sim._mjx_data.qpos.shape, sim._mjx_data.qvel.shape
262290
sim.step()

0 commit comments

Comments
 (0)