@@ -48,13 +48,13 @@ class MultiDroneRacingEnv(gymnasium.Env):
48
48
"""A Gymnasium environment for drone racing simulations.
49
49
50
50
This environment simulates a drone racing scenario where one or more drones navigate through a
51
- series of gates in a predefined track. It uses the Sim class for physics simulation and supports
52
- various configuration options for randomization, disturbances, and physics models.
51
+ series of gates in a predefined track. It supports various configuration options for
52
+ randomization, disturbances, and physics models.
53
53
54
54
The environment provides:
55
55
- A customizable track with gates and obstacles
56
56
- Configurable simulation and control frequencies
57
- - Support for different physics models (e.g., PyBullet, mathematical dynamics)
57
+ - Support for different physics models (e.g., identified dynamics, analytical dynamics)
58
58
- Randomization of drone properties and initial conditions
59
59
- Disturbance modeling for realistic flight conditions
60
60
- Symbolic expressions for advanced control techniques (optional)
@@ -86,6 +86,7 @@ class MultiDroneRacingEnv(gymnasium.Env):
86
86
87
87
def __init__ (
88
88
self ,
89
+ n_envs : int ,
89
90
n_drones : int ,
90
91
freq : int ,
91
92
sim_config : ConfigDict ,
@@ -111,6 +112,7 @@ def __init__(
111
112
"""
112
113
super ().__init__ ()
113
114
self .sim = Sim (
115
+ n_worlds = n_envs ,
114
116
n_drones = n_drones ,
115
117
physics = sim_config .physics ,
116
118
control = sim_config .get ("control" , "state" ),
@@ -130,25 +132,24 @@ def __init__(
130
132
self .random_resets = random_resets
131
133
self .sensor_range = sensor_range
132
134
self .gates , self .obstacles , self .drone = self .load_track (track )
133
- self .n_gates = len (track .gates )
134
135
specs = {} if disturbances is None else disturbances
135
136
self .disturbances = {mode : rng_spec2fn (spec ) for mode , spec in specs .items ()}
136
137
specs = {} if randomizations is None else randomizations
137
138
self .randomizations = {mode : rng_spec2fn (spec ) for mode , spec in specs .items ()}
138
139
139
140
# Spaces
140
141
self .action_space = spaces .Box (low = - 1 , high = 1 , shape = (n_drones , 13 ))
141
- n_obstacles = len (track .obstacles )
142
+ n_gates , n_obstacles = len ( track . gates ), len (track .obstacles )
142
143
self .observation_space = spaces .Dict (
143
144
{
144
145
"pos" : spaces .Box (low = - np .inf , high = np .inf , shape = (3 ,)),
145
146
"rpy" : spaces .Box (low = - np .inf , high = np .inf , shape = (3 ,)),
146
147
"vel" : spaces .Box (low = - np .inf , high = np .inf , shape = (3 ,)),
147
148
"ang_vel" : spaces .Box (low = - np .inf , high = np .inf , shape = (3 ,)),
148
- "target_gate" : spaces .Discrete (self . n_gates , start = - 1 ),
149
- "gates_pos" : spaces .Box (low = - np .inf , high = np .inf , shape = (self . n_gates , 3 )),
150
- "gates_rpy" : spaces .Box (low = - np .pi , high = np .pi , shape = (self . n_gates , 3 )),
151
- "gates_visited" : spaces .Box (low = 0 , high = 1 , shape = (self . n_gates ,), dtype = bool ),
149
+ "target_gate" : spaces .Discrete (n_gates , start = - 1 ),
150
+ "gates_pos" : spaces .Box (low = - np .inf , high = np .inf , shape = (n_gates , 3 )),
151
+ "gates_rpy" : spaces .Box (low = - np .pi , high = np .pi , shape = (n_gates , 3 )),
152
+ "gates_visited" : spaces .Box (low = 0 , high = 1 , shape = (n_gates ,), dtype = bool ),
152
153
"obstacles_pos" : spaces .Box (low = - np .inf , high = np .inf , shape = (n_obstacles , 3 )),
153
154
"obstacles_visited" : spaces .Box (low = 0 , high = 1 , shape = (n_obstacles ,), dtype = bool ),
154
155
}
@@ -162,7 +163,7 @@ def __init__(
162
163
self .target_gate = np .zeros (self .sim .n_drones , dtype = int )
163
164
self ._steps = 0
164
165
self ._last_drone_pos = np .zeros ((self .sim .n_drones , 3 ))
165
- self .gates_visited = np .zeros ((self .sim .n_drones , self . n_gates ), dtype = bool )
166
+ self .gates_visited = np .zeros ((self .sim .n_drones , n_gates ), dtype = bool )
166
167
self .obstacles_visited = np .zeros ((self .sim .n_drones , n_obstacles ), dtype = bool )
167
168
168
169
# Compile the reset and step functions with custom hooks
@@ -242,17 +243,18 @@ def step(
242
243
)
243
244
self .sim .data = self .warp_disabled_drones (self .sim .data , self .disabled_drones )
244
245
# TODO: Clean up the accelerated functions
246
+ n_gates = len (self .gates ["pos" ])
247
+ gate_id = self .target_gate % n_gates
245
248
passed = self ._gate_passed (
246
- self . target_gate ,
249
+ gate_id ,
247
250
self .gates ["mocap_ids" ],
248
251
self .sim .data .mjx_data .mocap_pos [0 ],
249
252
self .sim .data .mjx_data .mocap_quat [0 ],
250
253
self .sim .data .states .pos [0 ],
251
254
self ._last_drone_pos ,
252
- self .n_gates ,
253
255
)
254
256
self .target_gate += np .array (passed ) * ~ self .disabled_drones
255
- self .target_gate [self .target_gate >= self . n_gates ] = - 1
257
+ self .target_gate [self .target_gate >= n_gates ] = - 1
256
258
self ._last_drone_pos = self .sim .data .states .pos [0 ]
257
259
return self .obs (), self .reward (), self .terminated (), False , self .info ()
258
260
@@ -397,8 +399,8 @@ def load_track(self, track: dict) -> tuple[dict, dict, dict]:
397
399
398
400
def load_contact_masks (self ) -> NDArray [np .bool_ ]:
399
401
"""Load contact masks for the simulation that zero out irrelevant contacts per drone."""
400
- n_obstacles = len (self .obstacles ["pos" ])
401
- object_contacts = n_obstacles + self . n_gates * 5 + 1 # 5 geoms per gate, 1 for the floor
402
+ n_gates , n_obstacles = len ( self . gates [ "pos" ]), len (self .obstacles ["pos" ])
403
+ object_contacts = n_obstacles + n_gates * 5 + 1 # 5 geoms per gate, 1 for the floor
402
404
drone_contacts = (self .sim .n_drones - 1 ) * self .sim .n_drones // 2
403
405
n_contacts = self .sim .n_drones * object_contacts + drone_contacts
404
406
masks = np .zeros ((self .sim .n_drones , n_contacts ), dtype = bool )
@@ -462,21 +464,20 @@ def _load_track_into_sim(self, gates: dict, obstacles: dict):
462
464
@staticmethod
463
465
@jax .jit
464
466
def _gate_passed (
465
- target_gate : NDArray ,
467
+ gate_id : int ,
466
468
mocap_ids : NDArray ,
467
469
mocap_pos : Array ,
468
470
mocap_quat : Array ,
469
471
drone_pos : Array ,
470
472
last_drone_pos : NDArray ,
471
- n_gates : int ,
472
473
) -> bool :
473
474
"""Check if the drone has passed a gate.
474
475
475
476
Returns:
476
477
True if the drone has passed a gate, else False.
477
478
"""
478
479
# TODO: Test. Cover cases with no gates.
479
- ids = mocap_ids [target_gate % n_gates ]
480
+ ids = mocap_ids [gate_id ]
480
481
gate_pos = mocap_pos [ids ]
481
482
gate_rot = JaxR .from_quat (mocap_quat [ids ][..., [1 , 2 , 3 , 0 ]])
482
483
gate_size = (0.45 , 0.45 )
0 commit comments