Skip to content

Commit 86f40de

Browse files
committed
Fix name collisions between core env and gym envs. Improve benchmarking
1 parent: 53d8b36 · commit: 86f40de

File tree

5 files changed: +125 −55 lines changed

benchmarks/main.py

+22-9
Original file line number | Diff line number | Diff line change
@@ -5,12 +5,12 @@
55
from sim import time_multi_drone_reset, time_multi_drone_step, time_sim_reset, time_sim_step
66

77

8-
def print_benchmark_results(name: str, timings: list[float]):
9-
print(f"\nResults for {name}:")
8+
def print_benchmark_results(name: str, timings: list[float], n_envs: int, device: str):
9+
print(f"\nResults for {name} ({n_envs} envs, {device}):")
1010
print(f"Mean/std: {np.mean(timings):.2e}s +- {np.std(timings):.2e}s")
1111
print(f"Min time: {np.min(timings):.2e}s")
1212
print(f"Max time: {np.max(timings):.2e}s")
13-
print(f"FPS: {1 / np.mean(timings):.2f}")
13+
print(f"FPS: {n_envs / np.mean(timings):.2f}")
1414

1515

1616
def main(
@@ -19,18 +19,31 @@ def main(
1919
multi_drone: bool = False,
2020
reset: bool = True,
2121
step: bool = True,
22+
vec_size: int = 1,
23+
device: str = "cpu",
2224
):
2325
reset_fn, step_fn = time_sim_reset, time_sim_step
2426
if multi_drone:
2527
reset_fn, step_fn = time_multi_drone_reset, time_multi_drone_step
2628
if reset:
27-
timings = reset_fn(n_tests=n_tests, number=number)
28-
print_benchmark_results(name="Racing env reset", timings=timings / number)
29+
timings = reset_fn(n_tests=n_tests, number=number, n_envs=vec_size, device=device)
30+
print_benchmark_results(
31+
name="Racing env reset", timings=timings / number, n_envs=vec_size, device=device
32+
)
2933
if step:
30-
timings = step_fn(n_tests=n_tests, number=number)
31-
print_benchmark_results(name="Racing env steps", timings=timings / number)
32-
timings = step_fn(n_tests=n_tests, number=number, physics_mode="sys_id")
33-
print_benchmark_results(name="Racing env steps (sys_id backend)", timings=timings / number)
34+
timings = step_fn(n_tests=n_tests, number=number, n_envs=vec_size, device=device)
35+
print_benchmark_results(
36+
name="Racing env steps", timings=timings / number, n_envs=vec_size, device=device
37+
)
38+
timings = step_fn(
39+
n_tests=n_tests, number=number, physics_mode="sys_id", n_envs=vec_size, device=device
40+
)
41+
print_benchmark_results(
42+
name="Racing env steps (sys_id backend)",
43+
timings=timings / number,
44+
n_envs=vec_size,
45+
device=device,
46+
)
3447
# timings = step_fn(n_tests=n_tests, number=number, physics_mode="mujoco")
3548
# print_benchmark_results(name="Sim steps (mujoco backend)", timings=timings / number)
3649

benchmarks/sim.py

+66-33
Original file line number | Diff line number | Diff line change
@@ -19,11 +19,13 @@
1919

2020
env_setup_code = """
2121
import gymnasium
22+
import jax
2223
2324
import lsy_drone_racing
2425
25-
env = gymnasium.make(
26+
env = gymnasium.make_vec(
2627
config.env.id,
28+
num_envs={num_envs},
2729
freq=config.env.freq,
2830
sim_config=config.sim,
2931
sensor_range=config.env.sensor_range,
@@ -32,21 +34,29 @@
3234
randomizations=config.env.get("randomizations"),
3335
random_resets=config.env.random_resets,
3436
seed=config.env.seed,
37+
device='{device}',
3538
)
39+
40+
# JIT compile the reset and step functions
3641
env.reset()
37-
env.step(env.action_space.sample()) # JIT compile
38-
env.reset()
39-
env.action_space.seed(42)
40-
action = env.action_space.sample()
42+
env.step(env.action_space.sample())
43+
jax.block_until_ready(env.unwrapped.data)
44+
# JIT masked reset (used in autoreset)
45+
mask = env.unwrapped.data.marked_for_reset
46+
mask = mask.at[0].set(True)
47+
env.unwrapped._reset(mask=mask) # enforce masked reset compile
48+
jax.block_until_ready(env.unwrapped.data)
49+
env.action_space.seed(2)
4150
"""
4251

4352
attitude_env_setup_code = """
4453
import gymnasium
54+
import jax
4555
4656
import lsy_drone_racing
4757
48-
env = gymnasium.make('DroneRacingAttitude-v0',
49-
config.env.id,
58+
env = gymnasium.make_vec('DroneRacingAttitude-v0',
59+
num_envs={num_envs},
5060
freq=config.env.freq,
5161
sim_config=config.sim,
5262
sensor_range=config.env.sensor_range,
@@ -55,12 +65,19 @@
5565
randomizations=config.env.get("randomizations"),
5666
random_resets=config.env.random_resets,
5767
seed=config.env.seed,
68+
device='{device}',
5869
)
70+
71+
# JIT compile the reset and step functions
5972
env.reset()
60-
env.step(env.action_space.sample()) # JIT compile
61-
env.reset()
62-
env.action_space.seed(42)
63-
action = env.action_space.sample()
73+
env.step(env.action_space.sample())
74+
jax.block_until_ready(env.unwrapped.data)
75+
# JIT masked reset (used in autoreset)
76+
mask = env.unwrapped.data.marked_for_reset
77+
mask = mask.at[0].set(True)
78+
env.unwrapped._reset(mask=mask) # enforce masked reset compile
79+
jax.block_until_ready(env.unwrapped.data)
80+
env.action_space.seed(2)
6481
"""
6582

6683
load_multi_drone_config_code = f"""
@@ -76,9 +93,11 @@
7693
import jax
7794
7895
import lsy_drone_racing
79-
from lsy_drone_racing.envs.multi_drone_race import MultiDroneRaceEnv
96+
from lsy_drone_racing.envs.multi_drone_race import VecMultiDroneRaceEnv
8097
81-
env = gymnasium.make('MultiDroneRacing-v0',
98+
99+
env = gymnasium.make_vec('MultiDroneRacing-v0',
100+
num_envs={num_envs},
82101
n_drones=config.env.n_drones,
83102
freq=config.env.freq,
84103
sim_config=config.sim,
@@ -88,58 +107,72 @@
88107
randomizations=config.env.get("randomizations"),
89108
random_resets=config.env.random_resets,
90109
seed=config.env.seed,
91-
device='cpu',
110+
device='{device}',
92111
)
93112
113+
# JIT compile the reset and step functions
94114
env.reset()
95-
# JIT step
96115
env.step(env.action_space.sample())
97116
jax.block_until_ready(env.unwrapped.data)
98117
# JIT masked reset (used in autoreset)
99118
mask = env.unwrapped.data.marked_for_reset
100119
mask = mask.at[0].set(True)
101-
super(MultiDroneRaceEnv, env.unwrapped).reset(mask=mask) # enforce masked reset compile
120+
env.unwrapped._reset(mask=mask) # enforce masked reset compile
102121
jax.block_until_ready(env.unwrapped.data)
103122
env.action_space.seed(2)
104123
"""
105124

106125

107-
def time_sim_reset(n_tests: int = 10, number: int = 1) -> NDArray[np.floating]:
108-
setup = load_config_code + env_setup_code
126+
def time_sim_reset(
127+
n_tests: int = 10, number: int = 1, n_envs: int = 1, device: str = "cpu"
128+
) -> NDArray[np.floating]:
129+
setup = load_config_code + env_setup_code.format(num_envs=n_envs, device=device)
109130
stmt = """env.reset()"""
110131
return np.array(timeit.repeat(stmt=stmt, setup=setup, number=number, repeat=n_tests))
111132

112133

113134
def time_sim_step(
114-
n_tests: int = 10, number: int = 1, physics_mode: str = "analytical"
135+
n_tests: int = 10,
136+
number: int = 1,
137+
physics_mode: str = "analytical",
138+
n_envs: int = 1,
139+
device: str = "cpu",
115140
) -> NDArray[np.floating]:
116141
modify_config_code = f"""config.sim.physics = '{physics_mode}'\n"""
117-
setup = load_config_code + modify_config_code + env_setup_code + "\nenv.reset()"
118-
stmt = """env.step(action)"""
142+
_env_setup_code = env_setup_code.format(num_envs=n_envs, device=device)
143+
setup = load_config_code + modify_config_code + _env_setup_code + "\nenv.reset()"
144+
stmt = """env.step(env.action_space.sample())"""
119145
return np.array(timeit.repeat(stmt=stmt, setup=setup, number=number, repeat=n_tests))
120146

121147

122-
def time_sim_attitude_step(n_tests: int = 10, number: int = 1) -> NDArray[np.floating]:
123-
setup = load_config_code + attitude_env_setup_code + "\nenv.reset()"
124-
stmt = """env.step(action)"""
148+
def time_sim_attitude_step(
149+
n_tests: int = 10, number: int = 1, n_envs: int = 1, device: str = "cpu"
150+
) -> NDArray[np.floating]:
151+
env_setup_code = attitude_env_setup_code.format(num_envs=n_envs, device=device)
152+
setup = load_config_code + env_setup_code + "\nenv.reset()"
153+
stmt = """env.step(env.action_space.sample())"""
125154
return np.array(timeit.repeat(stmt=stmt, setup=setup, number=number, repeat=n_tests))
126155

127156

128-
def time_multi_drone_reset(n_tests: int = 10, number: int = 1) -> NDArray[np.floating]:
129-
setup = load_multi_drone_config_code + multi_drone_env_setup_code + "\nenv.reset()"
157+
def time_multi_drone_reset(
158+
n_tests: int = 10, number: int = 1, n_envs: int = 1, device: str = "cpu"
159+
) -> NDArray[np.floating]:
160+
env_setup_code = multi_drone_env_setup_code.format(num_envs=n_envs, device=device)
161+
setup = load_multi_drone_config_code + env_setup_code + "\nenv.reset()"
130162
stmt = """env.reset()"""
131163
return np.array(timeit.repeat(stmt=stmt, setup=setup, number=number, repeat=n_tests))
132164

133165

134166
def time_multi_drone_step(
135-
n_tests: int = 10, number: int = 100, physics_mode: str = "analytical"
167+
n_tests: int = 10,
168+
number: int = 100,
169+
physics_mode: str = "analytical",
170+
n_envs: int = 1,
171+
device: str = "cpu",
136172
) -> NDArray[np.floating]:
137173
modify_config_code = f"""config.sim.physics = '{physics_mode}'\n"""
138-
setup = (
139-
load_multi_drone_config_code
140-
+ modify_config_code
141-
+ multi_drone_env_setup_code
142-
+ "\nenv.reset()"
143-
)
174+
env_setup_code = multi_drone_env_setup_code.format(num_envs=n_envs, device=device)
175+
176+
setup = load_multi_drone_config_code + modify_config_code + env_setup_code + "\nenv.reset()"
144177
stmt = """env.step(env.action_space.sample())"""
145178
return np.array(timeit.repeat(stmt=stmt, setup=setup, number=number, repeat=n_tests))

lsy_drone_racing/envs/drone_race.py

+4-4
Original file line number | Diff line number | Diff line change
@@ -78,7 +78,7 @@ def reset(self, seed: int | None = None, options: dict | None = None) -> tuple[d
7878
Returns:
7979
The initial observation and info.
8080
"""
81-
obs, info = super().reset(seed=seed, options=options)
81+
obs, info = self._reset(seed=seed, options=options)
8282
obs = {k: v[0, 0] for k, v in obs.items()}
8383
info = {k: v[0, 0] for k, v in info.items()}
8484
return obs, info
@@ -92,7 +92,7 @@ def step(self, action: NDArray[np.floating]) -> tuple[dict, float, bool, bool, d
9292
Returns:
9393
Observation, reward, terminated, truncated, and info.
9494
"""
95-
obs, reward, terminated, truncated, info = super().step(action)
95+
obs, reward, terminated, truncated, info = self._step(action)
9696
obs = {k: v[0, 0] for k, v in obs.items()}
9797
info = {k: v[0, 0] for k, v in info.items()}
9898
return obs, float(reward[0, 0]), bool(terminated[0, 0]), bool(truncated[0, 0]), info
@@ -163,7 +163,7 @@ def reset(self, seed: int | None = None, options: dict | None = None) -> tuple[d
163163
Returns:
164164
The initial observation and info.
165165
"""
166-
obs, info = super().reset(seed=seed, options=options)
166+
obs, info = self._reset(seed=seed, options=options)
167167
obs = {k: v[:, 0] for k, v in obs.items()}
168168
info = {k: v[:, 0] for k, v in info.items()}
169169
return obs, info
@@ -179,7 +179,7 @@ def step(
179179
Returns:
180180
Observation, reward, terminated, truncated, and info.
181181
"""
182-
obs, reward, terminated, truncated, info = super().step(action)
182+
obs, reward, terminated, truncated, info = self._step(action)
183183
obs = {k: v[:, 0] for k, v in obs.items()}
184184
info = {k: v[:, 0] for k, v in info.items()}
185185
return obs, reward[:, 0], terminated[:, 0], truncated[:, 0], info

lsy_drone_racing/envs/multi_drone_race.py

+24-2
Original file line number | Diff line number | Diff line change
@@ -82,7 +82,7 @@ def reset(self, seed: int | None = None, options: dict | None = None) -> tuple[d
8282
Returns:
8383
Observation and info for all drones.
8484
"""
85-
obs, info = super().reset(seed=seed, options=options)
85+
obs, info = self._reset(seed=seed, options=options)
8686
obs = {k: v[0] for k, v in obs.items()}
8787
info = {k: v[0] for k, v in info.items()}
8888
return obs, info
@@ -98,7 +98,7 @@ def step(
9898
Returns:
9999
Observation, reward, terminated, truncated, and info for all drones.
100100
"""
101-
obs, reward, terminated, truncated, info = super().step(action)
101+
obs, reward, terminated, truncated, info = self._step(action)
102102
obs = {k: v[0] for k, v in obs.items()}
103103
info = {k: v[0] for k, v in info.items()}
104104
return obs, reward[0], terminated[0], truncated[0], info
@@ -165,3 +165,25 @@ def __init__(
165165
build_observation_space(n_gates, n_obstacles), n_drones
166166
)
167167
self.observation_space = batch_space(batch_space(self.single_observation_space), num_envs)
168+
169+
def reset(self, seed: int | None = None, options: dict | None = None) -> tuple[dict, dict]:
170+
"""Reset the environment for all drones.
171+
172+
Args:
173+
seed: Random seed.
174+
options: Additional reset options. Not used.
175+
176+
Returns:
177+
Observation and info for all drones.
178+
"""
179+
return self._reset(seed=seed, options=options)
180+
181+
def step(
182+
self, action: NDArray[np.floating]
183+
) -> tuple[dict, NDArray[np.floating], NDArray[np.bool_], NDArray[np.bool_], dict]:
184+
"""Step the environment for all drones.
185+
186+
Args:
187+
action: Action for all drones, i.e., a batch of (n_drones, action_dim) arrays.
188+
"""
189+
return self._step(action)

lsy_drone_racing/envs/race_core.py

+9-7
Original file line number | Diff line number | Diff line change
@@ -270,7 +270,7 @@ def __init__(
270270
self.device,
271271
)
272272

273-
def reset(
273+
def _reset(
274274
self, *, seed: int | None = None, options: dict | None = None, mask: Array | None = None
275275
) -> tuple[dict[str, NDArray[np.floating]], dict]:
276276
"""Reset the environment.
@@ -292,10 +292,10 @@ def reset(
292292
# Randomization of gates, obstacles and drones is compiled into the sim reset function with
293293
# the sim.reset_hook function, so we don't need to explicitly do it here
294294
self.sim.reset(mask=mask)
295-
self.data = self._reset(self.data, self.sim.data.states.pos, mask)
295+
self.data = self._reset_env_data(self.data, self.sim.data.states.pos, mask)
296296
return self.obs(), self.info()
297297

298-
def step(
298+
def _step(
299299
self, action: NDArray[np.floating]
300300
) -> tuple[dict[str, NDArray[np.floating]], float, bool, bool, dict]:
301301
"""Step the firmware_wrapper class and its environment.
@@ -321,10 +321,12 @@ def step(
321321
# previous flags, not the ones from the current step
322322
marked_for_reset = self.data.marked_for_reset
323323
# Apply the environment logic with updated simulation data.
324-
self.data = self._step(self.data, drone_pos, drone_quat, mocap_pos, mocap_quat, contacts)
324+
self.data = self._step_env(
325+
self.data, drone_pos, drone_quat, mocap_pos, mocap_quat, contacts
326+
)
325327
# Auto-reset envs. Add configuration option to disable for single-world envs
326328
if self.autoreset and marked_for_reset.any():
327-
self.reset(mask=marked_for_reset)
329+
self._reset(mask=marked_for_reset)
328330
return self.obs(), self.reward(), self.terminated(), self.truncated(), self.info()
329331

330332
def apply_action(self, action: NDArray[np.floating]):
@@ -422,7 +424,7 @@ def symbolic_model(self) -> SymbolicModel:
422424

423425
@staticmethod
424426
@jax.jit
425-
def _reset(data: EnvData, drone_pos: Array, mask: Array | None = None) -> EnvData:
427+
def _reset_env_data(data: EnvData, drone_pos: Array, mask: Array | None = None) -> EnvData:
426428
"""Reset auxiliary variables of the environment data."""
427429
mask = jp.ones(data.steps.shape, dtype=bool) if mask is None else mask
428430
target_gate = jp.where(mask[..., None], 0, data.target_gate)
@@ -443,7 +445,7 @@ def _reset(data: EnvData, drone_pos: Array, mask: Array | None = None) -> EnvDat
443445

444446
@staticmethod
445447
@jax.jit
446-
def _step(
448+
def _step_env(
447449
data: EnvData,
448450
drone_pos: Array,
449451
drone_quat: Array,

0 commit comments

Comments (0)