10
10
from crazyflow .sim .physics import Physics
11
11
12
12
13
+ def available_backends () -> list [str ]:
14
+ """Return list of available JAX backends."""
15
+ backends = []
16
+ for backend in ["tpu" , "gpu" , "cpu" ]:
17
+ try :
18
+ jax .devices (backend )
19
+ except RuntimeError :
20
+ pass
21
+ else :
22
+ backends .append (backend )
23
+ return backends
24
+
25
+
26
+ def skip_unavailable_device (device : str ):
27
+ if device not in available_backends ():
28
+ pytest .skip (f"{ device } device not available" )
29
+
30
+
13
31
@pytest .mark .unit
14
32
@pytest .mark .parametrize ("physics" , Physics )
15
33
@pytest .mark .parametrize ("device" , ["gpu" , "cpu" ])
@@ -20,6 +38,7 @@ def test_sim_creation(
20
38
physics : Physics , device : str , control : Control , controller : Controller , n_worlds : int
21
39
):
22
40
n_drones = 1
41
+ skip_unavailable_device (device )
23
42
24
43
def create_sim () -> Sim :
25
44
return Sim (
@@ -58,6 +77,7 @@ def create_sim() -> Sim:
58
77
@pytest .mark .unit
59
78
@pytest .mark .parametrize ("device" , ["gpu" , "cpu" ])
60
79
def test_setup (device : str ):
80
+ skip_unavailable_device (device )
61
81
sim = Sim (n_worlds = 2 , n_drones = 3 , device = device )
62
82
sim .setup ()
63
83
@@ -69,6 +89,7 @@ def test_setup(device: str):
69
89
@pytest .mark .parametrize ("n_drones" , [1 , 3 ])
70
90
def test_reset (device : str , physics : Physics , n_worlds : int , n_drones : int ):
71
91
"""Test that reset without mask resets all worlds to default state."""
92
+ skip_unavailable_device (device )
72
93
sim = Sim (n_worlds = n_worlds , n_drones = n_drones , physics = physics , device = device )
73
94
if physics == Physics .mujoco :
74
95
return # MuJoCo is not yet supported. TODO: Enable once supported
@@ -113,6 +134,7 @@ def test_reset(device: str, physics: Physics, n_worlds: int, n_drones: int):
113
134
@pytest .mark .parametrize ("physics" , Physics )
114
135
def test_reset_masked (device : str , physics : Physics ):
115
136
"""Test that reset with mask only resets specified worlds."""
137
+ skip_unavailable_device (device )
116
138
sim = Sim (n_worlds = 2 , n_drones = 1 , physics = physics , device = device )
117
139
118
140
# Modify states
@@ -169,6 +191,7 @@ def test_sim_step(
169
191
controller : Controller ,
170
192
device : str ,
171
193
):
194
+ skip_unavailable_device (device )
172
195
if n_drones * n_worlds > 1 and controller == Controller .pycffirmware :
173
196
return # PyCFFirmware does not support multiple drones
174
197
sim = Sim (
@@ -218,6 +241,7 @@ def test_sim_control(control: Control, control_freq: int):
218
241
@pytest .mark .unit
219
242
@pytest .mark .parametrize ("device" , ["gpu" , "cpu" ])
220
243
def test_sim_state_control (device : str ):
244
+ skip_unavailable_device (device )
221
245
sim = Sim (n_worlds = 2 , n_drones = 3 , control = Control .state , device = device )
222
246
cmd = np .random .rand (sim .n_worlds , sim .n_drones , 13 )
223
247
sim .state_control (cmd )
@@ -228,6 +252,7 @@ def test_sim_state_control(device: str):
228
252
@pytest .mark .unit
229
253
@pytest .mark .parametrize ("device" , ["gpu" , "cpu" ])
230
254
def test_sim_attitude_control (device : str ):
255
+ skip_unavailable_device (device )
231
256
sim = Sim (n_worlds = 2 , n_drones = 3 , control = Control .attitude , device = device )
232
257
cmd = np .random .rand (sim .n_worlds , sim .n_drones , 4 )
233
258
sim .attitude_control (cmd )
@@ -238,6 +263,7 @@ def test_sim_attitude_control(device: str):
238
263
@pytest .mark .parametrize ("device" , ["gpu" , "cpu" ])
239
264
@pytest .mark .render
240
265
def test_render (device : str ):
266
+ skip_unavailable_device (device )
241
267
sim = Sim (device = device )
242
268
sim .render ()
243
269
sim .viewer .close ()
@@ -246,6 +272,7 @@ def test_render(device: str):
246
272
@pytest .mark .unit
247
273
@pytest .mark .parametrize ("device" , ["gpu" , "cpu" ])
248
274
def test_device (device : str ):
275
+ skip_unavailable_device (device )
249
276
sim = Sim (n_worlds = 2 , physics = Physics .sys_id , device = device )
250
277
sim .step ()
251
278
assert sim .states .pos .device == jax .devices (device )[0 ]
@@ -257,6 +284,7 @@ def test_device(device: str):
257
284
@pytest .mark .parametrize ("n_worlds" , [1 , 2 ])
258
285
@pytest .mark .parametrize ("n_drones" , [1 , 3 ])
259
286
def test_shape_consistency (device : str , n_drones : int , n_worlds : int ):
287
+ skip_unavailable_device (device )
260
288
sim = Sim (n_worlds = n_worlds , n_drones = n_drones , physics = Physics .sys_id , device = device )
261
289
qpos_shape , qvel_shape = sim ._mjx_data .qpos .shape , sim ._mjx_data .qvel .shape
262
290
sim .step ()
0 commit comments