import math
import warnings
from functools import partial
- from typing import Dict, Literal, Optional, Tuple
+ from typing import Literal

import jax
import jax.numpy as jnp
- import numpy as np
from flax.struct import dataclass
from gymnasium import spaces
from gymnasium.vector import VectorEnv
from gymnasium.vector.utils import batch_space
from jax import Array
+ from numpy.typing import NDArray

from crazyflow.control.controller import MAX_THRUST, MIN_THRUST, Control
from crazyflow.sim.core import Sim

@dataclass
class RescaleParams:
-     scale_factor: jnp.ndarray
-     mean: jnp.ndarray
+     scale_factor: Array
+     mean: Array


CONTROL_RESCALE_PARAMS = {
@@ -35,6 +35,12 @@ class RescaleParams:
}


+ def maybe_to_numpy(data: Array, convert: bool) -> NDArray | Array:
+     """Convert data to a numpy array if convert is True."""
+     # Deliberately not jitted: jax.jit-compiled functions can only return jax
+     # Arrays, so wrapping the conversion in jit (or jax.lax.cond) would
+     # silently skip the numpy conversion.
+     if convert:
+         return jax.device_get(data)
+     return data
+
+
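+ # Usage sketch (hypothetical call sites): maybe_to_numpy(jnp.ones(3), True)
+ # returns a numpy.ndarray, while maybe_to_numpy(jnp.ones(3), False) leaves the
+ # jax Array unchanged.
+
+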
class CrazyflowBaseEnv(VectorEnv):
    """JAX Gymnasium environment for Crazyflie simulation."""

@@ -103,7 +109,7 @@ def __init__(
        )
        self.observation_space = batch_space(self.single_observation_space, self.sim.n_worlds)

-     def step(self, action: Array) -> Tuple[Array, Array, Array, Array, Dict]:
+     def step(self, action: Array) -> tuple[Array, Array, Array, Array, dict]:
        assert self.action_space.contains(action), f"{action!r} ({type(action)}) invalid"
        action = jnp.array(action, device=self.device).reshape(
            (self.sim.n_worlds, self.sim.n_drones, -1)
@@ -139,8 +145,8 @@ def step(self, action: Array) -> Tuple[Array, Array, Array, Array, Dict]:
        return (
            self._get_obs(),
            reward,
-             self._maybe_to_numpy(terminated),
-             self._maybe_to_numpy(truncated),
+             maybe_to_numpy(terminated, self.return_datatype == "numpy"),
+             maybe_to_numpy(truncated, self.return_datatype == "numpy"),
            {},
        )

@@ -165,7 +171,7 @@ def _rescale_action(action: Array, control_type: str) -> Array:
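        # Affine de-normalization of normalized policy actions; presumably
        # scale_factor = (high - low) / 2 and mean = (high + low) / 2 per
        # control mode (assumed; the values live in CONTROL_RESCALE_PARAMS).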
        return action * params.scale_factor + params.mean

    def reset_all(
-         self, *, seed: Optional[int] = None, options: Optional[dict] = None
+         self, *, seed: int | None = None, options: dict | None = None
    ) -> tuple[dict[str, Array], dict]:
        super().reset(seed=seed)

@@ -226,42 +232,37 @@ def _reward() -> None:

    @staticmethod
    @jax.jit
-     def _terminated(dones: jax.Array, states: SimState, contacts: jax.Array) -> jnp.ndarray:
+     def _terminated(dones: Array, states: SimState, contacts: Array) -> Array:
        contact = jnp.any(contacts, axis=1)
        z_coords = states.pos[..., 2]
-         below_ground = jnp.any(
-             z_coords < -0.1, axis=1
-         )  # Should not be triggered due to collision checking
+         # Sanity check if we are below the ground. Should not be triggered due to collision checking.
+         below_ground = jnp.any(z_coords < -0.1, axis=1)
        terminated = jnp.logical_or(below_ground, contact)  # no further termination conditions
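        # Worlds that are already flagged done report terminated = False.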
        return jnp.where(dones, False, terminated)

    @staticmethod
    @jax.jit
    def _truncated(
-         dones: jax.Array, steps: jax.Array, max_episode_steps: jax.Array, n_substeps: jax.Array
-     ) -> jnp.ndarray:
+         dones: Array, steps: Array, max_episode_steps: Array, n_substeps: Array
+     ) -> Array:
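        # Assumes steps counts raw sim steps, so steps / n_substeps is the
        # number of completed environment steps.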
        truncated = steps / n_substeps >= max_episode_steps
        return jnp.where(dones, False, truncated)

    def render(self):
        self.sim.render()

-     def _get_obs(self) -> Dict[str, jnp.ndarray]:
+     def _get_obs(self) -> dict[str, Array]:
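        # Note: for "pos", only the z coordinate is included in the observation;
        # all other states are included in full.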
        obs = {
-             state: self._maybe_to_numpy(
+             state: maybe_to_numpy(
                getattr(self.sim.states, state)[..., 2]
                if state == "pos"
-                 else getattr(self.sim.states, state)
+                 else getattr(self.sim.states, state),
+                 self.return_datatype == "numpy",
            )
            for state in self.states_to_include_in_obs
        }
        return obs

-     def _maybe_to_numpy(self, data: Array) -> np.ndarray:
-         if self.return_datatype == "numpy" and not isinstance(data, np.ndarray):
-             return jax.device_get(data)
-         return data
-

class CrazyflowEnvReachGoal(CrazyflowBaseEnv):
    """JAX Gymnasium environment for Crazyflie simulation."""
@@ -284,7 +285,7 @@ def reward(self) -> Array:

    @staticmethod
    @jax.jit
-     def _reward(terminated: jax.Array, states: SimState, goal: jax.Array) -> jnp.ndarray:
+     def _reward(terminated: Array, states: SimState, goal: Array) -> Array:
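        # Reward decays exponentially with distance to the goal, staying in
        # (0, 1]; terminated worlds receive a flat -1 instead.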
        norm_distance = jnp.linalg.norm(states.pos - goal, axis=2)
        reward = jnp.exp(-2.0 * norm_distance)
        return jnp.where(terminated, -1.0, reward)
@@ -302,7 +303,7 @@ def reset(self, mask: Array) -> None:
        )
        self.goal = self.goal.at[mask].set(new_goals[mask])

-     def _get_obs(self) -> Dict[str, jnp.ndarray]:
+     def _get_obs(self) -> dict[str, Array]:
        obs = super()._get_obs()
        obs["difference_to_goal"] = [self.goal - self.sim.states.pos]
        return obs
@@ -329,7 +330,7 @@ def reward(self) -> Array:

    @staticmethod
    @jax.jit
-     def _reward(terminated: jax.Array, states: SimState, target_vel: jax.Array) -> jnp.ndarray:
+     def _reward(terminated: Array, states: SimState, target_vel: Array) -> Array:
        norm_distance = jnp.linalg.norm(states.vel - target_vel, axis=2)
        reward = jnp.exp(-norm_distance)
        return jnp.where(terminated, -1.0, reward)
@@ -347,7 +348,7 @@ def reset(self, mask: Array) -> None:
        )
        self.target_vel = self.target_vel.at[mask].set(new_target_vel[mask])

-     def _get_obs(self) -> Dict[str, jnp.ndarray]:
+     def _get_obs(self) -> dict[str, Array]:
        obs = super()._get_obs()
        obs["difference_to_target_vel"] = [self.target_vel - self.sim.states.vel]
        return obs