Skip to content

Commit ee8efea

Browse files
committed
[wip,broken] Add vectorized multi-drone env
1 parent 26141b0 commit ee8efea

File tree

5 files changed

+683
-23
lines changed

5 files changed

+683
-23
lines changed

benchmarks/sim.py

+1
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
import lsy_drone_racing
7474
7575
env = gymnasium.make('MultiDroneRacing-v0',
76+
n_envs=1, # TODO: Remove this for single-world envs
7677
n_drones=config.env.n_drones,
7778
freq=config.env.freq,
7879
sim_config=config.sim,

lsy_drone_racing/envs/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,10 @@
4949

5050
# region MultiEnvs
5151

52+
# TODO: Register specialized, non-vectorized envs for single worlds
5253
register(
5354
id="MultiDroneRacing-v0",
54-
entry_point="lsy_drone_racing.envs.multi_drone_race:MultiDroneRacingEnv",
55+
entry_point="lsy_drone_racing.envs.vec_drone_race:VectorMultiDroneRaceEnv",
5556
max_episode_steps=1800,
5657
disable_env_checker=True,
5758
)

lsy_drone_racing/envs/multi_drone_race.py

+19-18
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,13 @@ class MultiDroneRacingEnv(gymnasium.Env):
4848
"""A Gymnasium environment for drone racing simulations.
4949
5050
This environment simulates a drone racing scenario where a single drone navigates through a
51-
series of gates in a predefined track. It uses the Sim class for physics simulation and supports
52-
various configuration options for randomization, disturbances, and physics models.
51+
series of gates in a predefined track. It supports various configuration options for
52+
randomization, disturbances, and physics models.
5353
5454
The environment provides:
5555
- A customizable track with gates and obstacles
5656
- Configurable simulation and control frequencies
57-
- Support for different physics models (e.g., PyBullet, mathematical dynamics)
57+
- Support for different physics models (e.g., identified dynamics, analytical dynamics)
5858
- Randomization of drone properties and initial conditions
5959
- Disturbance modeling for realistic flight conditions
6060
- Symbolic expressions for advanced control techniques (optional)
@@ -86,6 +86,7 @@ class MultiDroneRacingEnv(gymnasium.Env):
8686

8787
def __init__(
8888
self,
89+
n_envs: int,
8990
n_drones: int,
9091
freq: int,
9192
sim_config: ConfigDict,
@@ -111,6 +112,7 @@ def __init__(
111112
"""
112113
super().__init__()
113114
self.sim = Sim(
115+
n_worlds=n_envs,
114116
n_drones=n_drones,
115117
physics=sim_config.physics,
116118
control=sim_config.get("control", "state"),
@@ -130,25 +132,24 @@ def __init__(
130132
self.random_resets = random_resets
131133
self.sensor_range = sensor_range
132134
self.gates, self.obstacles, self.drone = self.load_track(track)
133-
self.n_gates = len(track.gates)
134135
specs = {} if disturbances is None else disturbances
135136
self.disturbances = {mode: rng_spec2fn(spec) for mode, spec in specs.items()}
136137
specs = {} if randomizations is None else randomizations
137138
self.randomizations = {mode: rng_spec2fn(spec) for mode, spec in specs.items()}
138139

139140
# Spaces
140141
self.action_space = spaces.Box(low=-1, high=1, shape=(n_drones, 13))
141-
n_obstacles = len(track.obstacles)
142+
n_gates, n_obstacles = len(track.gates), len(track.obstacles)
142143
self.observation_space = spaces.Dict(
143144
{
144145
"pos": spaces.Box(low=-np.inf, high=np.inf, shape=(3,)),
145146
"rpy": spaces.Box(low=-np.inf, high=np.inf, shape=(3,)),
146147
"vel": spaces.Box(low=-np.inf, high=np.inf, shape=(3,)),
147148
"ang_vel": spaces.Box(low=-np.inf, high=np.inf, shape=(3,)),
148-
"target_gate": spaces.Discrete(self.n_gates, start=-1),
149-
"gates_pos": spaces.Box(low=-np.inf, high=np.inf, shape=(self.n_gates, 3)),
150-
"gates_rpy": spaces.Box(low=-np.pi, high=np.pi, shape=(self.n_gates, 3)),
151-
"gates_visited": spaces.Box(low=0, high=1, shape=(self.n_gates,), dtype=bool),
149+
"target_gate": spaces.Discrete(n_gates, start=-1),
150+
"gates_pos": spaces.Box(low=-np.inf, high=np.inf, shape=(n_gates, 3)),
151+
"gates_rpy": spaces.Box(low=-np.pi, high=np.pi, shape=(n_gates, 3)),
152+
"gates_visited": spaces.Box(low=0, high=1, shape=(n_gates,), dtype=bool),
152153
"obstacles_pos": spaces.Box(low=-np.inf, high=np.inf, shape=(n_obstacles, 3)),
153154
"obstacles_visited": spaces.Box(low=0, high=1, shape=(n_obstacles,), dtype=bool),
154155
}
@@ -162,7 +163,7 @@ def __init__(
162163
self.target_gate = np.zeros(self.sim.n_drones, dtype=int)
163164
self._steps = 0
164165
self._last_drone_pos = np.zeros((self.sim.n_drones, 3))
165-
self.gates_visited = np.zeros((self.sim.n_drones, self.n_gates), dtype=bool)
166+
self.gates_visited = np.zeros((self.sim.n_drones, n_gates), dtype=bool)
166167
self.obstacles_visited = np.zeros((self.sim.n_drones, n_obstacles), dtype=bool)
167168

168169
# Compile the reset and step functions with custom hooks
@@ -242,17 +243,18 @@ def step(
242243
)
243244
self.sim.data = self.warp_disabled_drones(self.sim.data, self.disabled_drones)
244245
# TODO: Clean up the accelerated functions
246+
n_gates = len(self.gates["pos"])
247+
gate_id = self.target_gate % n_gates
245248
passed = self._gate_passed(
246-
self.target_gate,
249+
gate_id,
247250
self.gates["mocap_ids"],
248251
self.sim.data.mjx_data.mocap_pos[0],
249252
self.sim.data.mjx_data.mocap_quat[0],
250253
self.sim.data.states.pos[0],
251254
self._last_drone_pos,
252-
self.n_gates,
253255
)
254256
self.target_gate += np.array(passed) * ~self.disabled_drones
255-
self.target_gate[self.target_gate >= self.n_gates] = -1
257+
self.target_gate[self.target_gate >= n_gates] = -1
256258
self._last_drone_pos = self.sim.data.states.pos[0]
257259
return self.obs(), self.reward(), self.terminated(), False, self.info()
258260

@@ -397,8 +399,8 @@ def load_track(self, track: dict) -> tuple[dict, dict, dict]:
397399

398400
def load_contact_masks(self) -> NDArray[np.bool_]:
399401
"""Load contact masks for the simulation that zero out irrelevant contacts per drone."""
400-
n_obstacles = len(self.obstacles["pos"])
401-
object_contacts = n_obstacles + self.n_gates * 5 + 1 # 5 geoms per gate, 1 for the floor
402+
n_gates, n_obstacles = len(self.gates["pos"]), len(self.obstacles["pos"])
403+
object_contacts = n_obstacles + n_gates * 5 + 1 # 5 geoms per gate, 1 for the floor
402404
drone_contacts = (self.sim.n_drones - 1) * self.sim.n_drones // 2
403405
n_contacts = self.sim.n_drones * object_contacts + drone_contacts
404406
masks = np.zeros((self.sim.n_drones, n_contacts), dtype=bool)
@@ -462,21 +464,20 @@ def _load_track_into_sim(self, gates: dict, obstacles: dict):
462464
@staticmethod
463465
@jax.jit
464466
def _gate_passed(
465-
target_gate: NDArray,
467+
gate_id: int,
466468
mocap_ids: NDArray,
467469
mocap_pos: Array,
468470
mocap_quat: Array,
469471
drone_pos: Array,
470472
last_drone_pos: NDArray,
471-
n_gates: int,
472473
) -> bool:
473474
"""Check if the drone has passed a gate.
474475
475476
Returns:
476477
True if the drone has passed a gate, else False.
477478
"""
478479
# TODO: Test. Cover cases with no gates.
479-
ids = mocap_ids[target_gate % n_gates]
480+
ids = mocap_ids[gate_id]
480481
gate_pos = mocap_pos[ids]
481482
gate_rot = JaxR.from_quat(mocap_quat[ids][..., [1, 2, 3, 0]])
482483
gate_size = (0.45, 0.45)

0 commit comments

Comments
 (0)