@@ -169,14 +169,15 @@ def reset(
169
169
Observation and info.
170
170
"""
171
171
if not self .config .env .random_resets :
172
+ self .np_random = np .random .default_rng (seed = self .config .env .seed )
172
173
self .sim .seed (self .config .env .seed )
173
174
if seed is not None :
175
+ self .np_random = np .random .default_rng (seed = self .config .env .seed )
174
176
self .sim .seed (seed )
175
177
# Randomization of gates, obstacles and drones is compiled into the sim reset function with
176
178
# the sim.reset_hook function, so we don't need to explicitly do it here
177
179
self .sim .reset ()
178
180
179
- # TODO: Add disturbances
180
181
self .target_gate = 0
181
182
self ._steps = 0
182
183
self ._last_drone_pos = self .sim .data .states .pos [0 , 0 ]
@@ -199,9 +200,13 @@ def step(
199
200
action: Full-state command [x, y, z, vx, vy, vz, ax, ay, az, yaw, rrate, prate, yrate]
200
201
to follow.
201
202
"""
202
- # TODO: Add action noise
203
203
assert action .shape == self .action_space .shape , f"Invalid action shape: { action .shape } "
204
- self .sim .state_control (action .reshape ((1 , 1 , 13 )))
204
+ action = action .reshape ((1 , 1 , 13 ))
205
+ if "action" in self .disturbances :
206
+ key , subkey = jax .random .split (self .sim .data .core .rng_key )
207
+ action += self .disturbances ["action" ](subkey , (1 , 1 , 13 ))
208
+ self .sim .data = self .sim .data .replace (core = self .sim .data .core .replace (rng_key = key ))
209
+ self .sim .state_control (action )
205
210
self .sim .step (self .sim .freq // self .config .env .freq )
206
211
self .target_gate += self .gate_passed ()
207
212
if self .target_gate == self .n_gates :
@@ -308,7 +313,6 @@ def load_track(self, track: dict) -> tuple[dict, dict, dict]:
308
313
309
314
def load_disturbances (self , disturbances : dict | None = None ) -> dict :
310
315
"""Load the disturbances from the config."""
311
- # TODO: Add jax disturbances for the simulator dynamics
312
316
if disturbances is None : # Default: no passive disturbances.
313
317
return {}
314
318
return {mode : self .load_random_fn (spec ) for mode , spec in disturbances .items ()}
@@ -344,6 +348,8 @@ def setup_sim(self):
344
348
states = self .sim .data .states .replace (pos = pos , quat = quat , vel = vel , rpy_rates = rpy_rates )
345
349
self .sim .data = self .sim .data .replace (states = states )
346
350
self .sim .reset_hook = build_reset_hook (self .randomization )
351
+ if "dynamics" in self .disturbances :
352
+ self .sim .disturbance_fn = build_dynamics_disturbance_fn (self .disturbances ["dynamics" ])
347
353
self .sim .build (mjx = False , data = False ) # Save the reset state and rebuild the reset function
348
354
349
355
def _load_track_into_sim (self , gates : dict , obstacles : dict ):
@@ -406,6 +412,20 @@ def reset_hook(data: SimData, mask: Array) -> SimData:
406
412
return reset_hook
407
413
408
414
415
+ def build_dynamics_disturbance_fn (
416
+ fn : Callable [[jax .random .PRNGKey , tuple [int ]], jax .Array ],
417
+ ) -> Callable [[SimData ], SimData ]:
418
+ """Build the dynamics disturbance function for the simulation."""
419
+
420
+ def dynamics_disturbance (data : SimData ) -> SimData :
421
+ key , subkey = jax .random .split (data .core .rng_key )
422
+ states = data .states
423
+ states = states .replace (force = states .force + fn (subkey , states .force .shape )) # World frame
424
+ return data .replace (states = states , core = data .core .replace (rng_key = key ))
425
+
426
+ return dynamics_disturbance
427
+
428
+
409
429
class DroneRacingThrustEnv (DroneRacingEnv ):
410
430
"""Drone racing environment with a collective thrust attitude command interface.
411
431
@@ -433,7 +453,10 @@ def step(
433
453
action: Thrust command [thrust, roll, pitch, yaw].
434
454
"""
435
455
assert action .shape == self .action_space .shape , f"Invalid action shape: { action .shape } "
436
- # TODO: Add action noise
456
+ if "action" in self .disturbances :
457
+ key , subkey = jax .random .split (self .sim .data .core .rng_key )
458
+ action += self .disturbances ["action" ](subkey , (1 , 1 , 4 ))
459
+ self .sim .data = self .sim .data .replace (core = self .sim .data .core .replace (rng_key = key ))
437
460
self .sim .attitude_control (action .reshape ((1 , 1 , 4 )).astype (np .float32 ))
438
461
self .sim .step (self .sim .freq // self .config .env .freq )
439
462
self .target_gate += self .gate_passed ()
0 commit comments