@@ -49,7 +49,7 @@ def __init__(
         *,
         jax_random_key: int,  # required for jax random number generator
         num_envs: int = 1,  # required for VectorEnv
-        max_episode_steps: int = 1000,
+        time_horizon_in_seconds: int = 10,
         return_datatype: Literal["numpy", "jax"] = "jax",
         **kwargs: dict,
     ):
@@ -58,7 +58,7 @@ def __init__(
         Args:
             jax_random_key: The random key for the jax random number generator.
             num_envs: The number of environments to run in parallel.
-            max_episode_steps: The maximum number of steps per episode.
+            time_horizon_in_seconds: The time horizon after which episodes are truncated.
             return_datatype: The data type for returned arrays, either "numpy" or "jax". If "numpy",
                 the returned arrays will be numpy arrays on the CPU. If "jax", the returned arrays
                 will be jax arrays on the "device" specified for the simulation.
@@ -71,7 +71,9 @@ def __init__(
         self.num_envs = num_envs
         self.return_datatype = return_datatype
         self.device = jax.devices(kwargs["device"])[0]
-        self.max_episode_steps = jnp.array(max_episode_steps, dtype=jnp.int32, device=self.device)
+        self.time_horizon_in_seconds = jnp.array(
+            time_horizon_in_seconds, dtype=jnp.int32, device=self.device
+        )
 
         self.sim = Sim(**kwargs)
 
@@ -224,7 +226,7 @@ def terminated(self) -> Array:
     @property
     def truncated(self) -> Array:
         return self._truncated(
-            self.prev_done, self.sim.steps, self.max_episode_steps, self.n_substeps
+            self.prev_done, self.sim.time, self.time_horizon_in_seconds, self.n_substeps
         )
 
     def _reward() -> None:
@@ -235,17 +237,18 @@ def _reward() -> None:
     def _terminated(dones: Array, states: SimState, contacts: Array) -> Array:
         contact = jnp.any(contacts, axis=1)
         z_coords = states.pos[..., 2]
-        # Sanity check if we are below the ground. Should not be triggered due to collision checking
-        below_ground = jnp.any(z_coords < -0.1, axis=1)
-        terminated = jnp.logical_or(below_ground, contact)  # no termination condition
+        below_ground = jnp.any(
+            z_coords < -0.1, axis=1
+        )  # Sanity check if we are below the ground. Should not be triggered due to collision checking
+        terminated = jnp.logical_or(below_ground, contact)
         return jnp.where(dones, False, terminated)
 
     @staticmethod
     @jax.jit
     def _truncated(
-        dones: Array, steps: Array, max_episode_steps: Array, n_substeps: Array
+        dones: Array, time: Array, time_horizon_in_seconds: Array, n_substeps: Array
     ) -> Array:
-        truncated = steps / n_substeps >= max_episode_steps
+        truncated = time >= time_horizon_in_seconds
         return jnp.where(dones, False, truncated)
 
     def render(self):
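
Note: with this change, truncation is driven by simulated time rather than by step counts divided by substeps. A minimal standalone sketch of the new check, using made-up values rather than the repository's API:

import jax.numpy as jnp

dones = jnp.array([False, True, False])       # worlds already done in the previous step
sim_time = jnp.array([10.0, 12.0, 3.0])       # simulated seconds per world
horizon = jnp.array(10, dtype=jnp.int32)      # time_horizon_in_seconds
# Truncate once the horizon is reached, but never re-flag worlds that were already done.
truncated = jnp.where(dones, False, sim_time >= horizon)  # -> [True, False, False]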
@@ -352,3 +355,40 @@ def _get_obs(self) -> dict[str, Array]:
         obs = super()._get_obs()
         obs["difference_to_target_vel"] = [self.target_vel - self.sim.states.vel]
         return obs
+
+
+class CrazyflowEnvLanding(CrazyflowBaseEnv):
+    """JAX Gymnasium environment for landing a simulated Crazyflie."""
+
+    def __init__(self, **kwargs: dict):
+        assert kwargs["n_drones"] == 1, "Currently only supported for one drone"
+
+        super().__init__(**kwargs)
+        self._obs_size += 3  # difference to goal position
+        self.single_observation_space = spaces.Box(
+            -jnp.inf, jnp.inf, shape=(self._obs_size,), dtype=jnp.float32
+        )
+        self.observation_space = batch_space(self.single_observation_space, self.sim.n_worlds)
+
+        self.goal = jnp.zeros((kwargs["n_worlds"], 3), dtype=jnp.float32)
+        self.goal = self.goal.at[..., 2].set(0.1)  # 10cm above ground
+
+    @property
+    def reward(self) -> Array:
+        return self._reward(self.terminated, self.sim.states, self.goal)
+
+    @staticmethod
+    @jax.jit
+    def _reward(terminated: Array, states: SimState, goal: Array) -> Array:
+        norm_distance = jnp.linalg.norm(states.pos - goal, axis=2)
+        speed = jnp.linalg.norm(states.vel, axis=2)
+        reward = jnp.exp(-2.0 * norm_distance) * jnp.exp(-2.0 * speed)
+        return jnp.where(terminated, -1.0, reward)
+
+    def reset(self, mask: Array) -> None:
+        super().reset(mask)
+
+    def _get_obs(self) -> dict[str, Array]:
+        obs = super()._get_obs()
+        obs["difference_to_goal"] = [self.goal - self.sim.states.pos]
+        return obs
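
The landing reward can be sanity-checked in isolation. A minimal sketch with made-up values; the (n_worlds, n_drones, 3) state layout is an assumption for illustration, not the repository's API:

import jax.numpy as jnp

pos = jnp.array([[[0.0, 0.0, 0.1]], [[0.5, 0.0, 1.0]]])   # hypothetical positions, shape (2, 1, 3)
vel = jnp.zeros((2, 1, 3))
goal = jnp.array([0.0, 0.0, 0.1])                          # 10 cm above ground
terminated = jnp.array([[False], [False]])

norm_distance = jnp.linalg.norm(pos - goal, axis=2)        # distance to goal per drone
speed = jnp.linalg.norm(vel, axis=2)
reward = jnp.exp(-2.0 * norm_distance) * jnp.exp(-2.0 * speed)
reward = jnp.where(terminated, -1.0, reward)               # terminated worlds get -1
# Hovering exactly at the goal with zero velocity yields the maximum reward of 1.0;
# the reward decays exponentially with both distance to the goal and speed.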