
Commit 7a8ef09

Accelerate gymnasium envs
1 parent 1a65f13 commit 7a8ef09

4 files changed: 66 additions, 85 deletions

benchmark/main.py

+11 −17

@@ -47,30 +47,25 @@ def profile_gym_env_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str):
         max_episode_steps=200,
         return_datatype="numpy",
         num_envs=sim_config.n_worlds,
-        jax_random_key=42,
         **sim_config,
     )

     # Action for going up (in attitude control)
-    action = np.array(
-        [[[-0.3, 0, 0, 0] for _ in range(sim_config.n_drones)] for _ in range(sim_config.n_worlds)],
-        dtype=np.float32,
-    ).reshape(sim_config.n_worlds, -1)
+    action = np.zeros((sim_config.n_worlds, 4), dtype=np.float32)
+    action[..., 0] = -0.3

-    # step through env once to ensure JIT compilation
-    _, _ = envs.reset_all(seed=42)
-    _, _, _, _, _ = envs.step(action)
-    _, _ = envs.reset_all(seed=42)
-    _, _, _, _, _ = envs.step(action)
-    _, _ = envs.reset_all(seed=42)
+    # Step through env once to ensure JIT compilation
+    envs.reset_all(seed=42)
+    envs.step(action)
+    envs.step(action)

-    jax.block_until_ready(envs.unwrapped.sim._mjx_data)  # Ensure JIT compiled dynamics
+    jax.block_until_ready(envs.unwrapped.sim.states.pos)  # Ensure JIT compiled dynamics

     # Step through the environment
     for _ in range(n_steps):
         tstart = time.perf_counter()
-        _, _, _, _, _ = envs.step(action)
-        jax.block_until_ready(envs.unwrapped.sim._mjx_data)
+        envs.step(action)
+        jax.block_until_ready(envs.unwrapped.sim.states.pos)
         times.append(time.perf_counter() - tstart)

     envs.close()
@@ -90,14 +85,13 @@ def profile_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str):
     sim.reset()
     sim.attitude_control(cmd)
     sim.step()
-    sim.reset()
-    jax.block_until_ready(sim._mjx_data)  # Ensure JIT compiled dynamics
+    jax.block_until_ready(sim.states.pos)  # Ensure JIT compiled dynamics

     for _ in range(n_steps):
         tstart = time.perf_counter()
         sim.attitude_control(cmd)
         sim.step()
-        jax.block_until_ready(sim._mjx_data)
+        jax.block_until_ready(sim.states.pos)
         times.append(time.perf_counter() - tstart)

     analyze_timings(times, n_steps, sim.n_worlds, sim.freq)
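
Note: the warm-up and timing loop above blocks on a concrete output array (envs.unwrapped.sim.states.pos) so that asynchronously dispatched JAX work actually finishes before the timer stops. A minimal sketch of that pattern with a generic jitted step function (step_fn and the shapes are illustrative, not part of this commit):

import time

import jax
import jax.numpy as jnp


@jax.jit
def step_fn(state: jax.Array) -> jax.Array:
    # Placeholder dynamics standing in for envs.step / sim.step.
    return state + 0.01 * jnp.sin(state)


state = jnp.zeros((1024, 3))
state = step_fn(state)        # Warm-up call triggers JIT compilation.
jax.block_until_ready(state)  # Wait until compilation and the first run have finished.

times = []
for _ in range(100):
    tstart = time.perf_counter()
    state = step_fn(state)
    jax.block_until_ready(state)  # Without this, only dispatch time would be measured.
    times.append(time.perf_counter() - tstart)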

benchmark/performance.py

+4 −9

@@ -26,7 +26,6 @@ def profile_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str):
     sim.reset()
     control_fn(cmd)
     sim.step()
-    control_fn(cmd)
     sim.step()
     sim.reset()
     jax.block_until_ready(sim.states.pos)
@@ -52,7 +51,6 @@ def profile_gym_env_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str):
         max_episode_steps=200,
         return_datatype="numpy",
         num_envs=sim_config.n_worlds,
-        jax_random_key=42,
         **sim_config,
     )

@@ -61,20 +59,17 @@ def profile_gym_env_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str):
     action[..., 0] = -0.3

     # Step through env once to ensure JIT compilation.
-    # TODO: Currently triggering recompiles also after the first full run. Investigate why and fix
-    # envs accordingly.
     envs.reset_all(seed=42)
-
-    for _ in range(envs.max_episode_steps + 1):  # Ensure all paths have been taken at least once
-        envs.step(action)
-
+    envs.step(action)
+    envs.step(action)  # Ensure all paths have been taken at least once
     envs.reset_all(seed=42)

     profiler = Profiler()
     profiler.start()

     for _ in range(n_steps):
-        _, _, _, _, _ = envs.step(action)
+        envs.step(action)
+        jax.block_until_ready(envs.unwrapped.sim.states.pos)

     profiler.stop()
     renderer = HTMLRenderer()
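
Note: Profiler and HTMLRenderer here are presumably pyinstrument's sampling profiler and HTML report renderer (their imports sit outside these hunks, so treat that as an assumption). A minimal sketch of that profiling pattern around a hot loop:

# Sketch only: assumes Profiler / HTMLRenderer come from pyinstrument.
from pyinstrument import Profiler
from pyinstrument.renderers import HTMLRenderer

profiler = Profiler()
profiler.start()

total = 0
for i in range(1_000_000):  # Stand-in for the envs.step(action) loop.
    total += i * i

profiler.stop()
html = profiler.output(HTMLRenderer())  # Render the recorded session as an HTML report.
with open("profile.html", "w") as f:
    f.write(html)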

crazyflow/gymnasium_envs/crazyflow.py

+48 −53

@@ -47,7 +47,6 @@ class CrazyflowBaseEnv(VectorEnv):
     def __init__(
         self,
         *,
-        jax_random_key: int,  # required for jax random number generator
         num_envs: int = 1,  # required for VectorEnv
         max_episode_steps: int = 1000,
         return_datatype: Literal["numpy", "jax"] = "jax",
@@ -56,7 +55,6 @@ def __init__(
         """Summary: Initializes the CrazyflowEnv.

         Args:
-            jax_random_key: The random key for the jax random number generator.
             num_envs: The number of environments to run in parallel.
             max_episode_steps: The maximum number of steps per episode.
             return_datatype: The data type for returned arrays, either "numpy" or "jax". If "numpy",
@@ -66,12 +64,14 @@ def __init__(
         """
         assert num_envs == kwargs["n_worlds"], "num_envs must be equal to n_worlds"

-        self.jax_key = jax.random.key(jax_random_key)
+        # Set random initial seed for JAX. For seeding, people should use the reset function
+        jax_seed = int(self.np_random.random() * 2**32)
+        self.jax_key = jax.random.key(jax_seed)

         self.num_envs = num_envs
         self.return_datatype = return_datatype
         self.device = jax.devices(kwargs["device"])[0]
-        self.max_episode_steps = jnp.array(max_episode_steps, dtype=jnp.int32, device=self.device)
+        self.max_episode_steps = max_episode_steps

         self.sim = Sim(**kwargs)

@@ -83,7 +83,7 @@ def __init__(
                 "Simulation frequency should be a multiple of control frequency. We can handle the other case, but we highly recommend to change the simulation frequency to a multiple of the control frequency."
             )

-        self.n_substeps = jnp.array(self.sim.freq // self.sim.control_freq)
+        self.n_substeps = self.sim.freq // self.sim.control_freq

         self.prev_done = jnp.zeros((self.sim.n_worlds), dtype=jnp.bool_, device=self.device)

@@ -111,44 +111,40 @@ def __init__(

     def step(self, action: Array) -> tuple[Array, Array, Array, Array, dict]:
         assert self.action_space.contains(action), f"{action!r} ({type(action)}) invalid"
-        action = jnp.array(action, device=self.device).reshape(
-            (self.sim.n_worlds, self.sim.n_drones, -1)
-        )
-
+        action = self._sanitize_action(action, self.sim.n_worlds, self.sim.n_drones, self.device)
         action = self._rescale_action(action, self.sim.control)

-        if self.sim.control == Control.state:
-            raise NotImplementedError(
-                "Possibly you want to control state differences instead of absolute states"
-            )
-            self.sim.state_control(action)
-        elif self.sim.control == Control.attitude:
-            self.sim.attitude_control(action)
-        elif self.sim.control == Control.thrust:
-            self.sim.thrust_control(action)
-        else:
-            raise ValueError(f"Invalid control type {self.sim.control}")
+        match self.sim.control:
+            case Control.state:
+                raise NotImplementedError(
+                    "Possibly you want to control state differences instead of absolute states"
+                )
+            case Control.attitude:
+                self.sim.attitude_control(action)
+            case Control.thrust:
+                self.sim.thrust_control(action)
+            case _:
+                raise ValueError(f"Invalid control type {self.sim.control}")

         for _ in range(self.n_substeps):
             self.sim.step()
-
         # Reset all environments which terminated or were truncated in the last step
         if jnp.any(self.prev_done):
             self.reset(mask=self.prev_done)

-        reward = self.reward
         terminated = self.terminated
         truncated = self.truncated
+        self.prev_done = self._done(terminated, truncated)

-        self.prev_done = jnp.logical_or(terminated, truncated)
+        convert = self.return_datatype == "numpy"
+        terminated = maybe_to_numpy(terminated, convert)
+        truncated = maybe_to_numpy(truncated, convert)
+        return self._obs(), self.reward, terminated, truncated, {}

-        return (
-            self._get_obs(),
-            reward,
-            maybe_to_numpy(terminated, self.return_datatype == "numpy"),
-            maybe_to_numpy(truncated, self.return_datatype == "numpy"),
-            {},
-        )
+    @staticmethod
+    @partial(jax.jit, static_argnames=["n_worlds", "n_drones", "device"])
+    def _sanitize_action(action: Array, n_worlds: int, n_drones: int, device: str) -> Array:
+        return jnp.array(action, device=device).reshape((n_worlds, n_drones, -1))

     @staticmethod
     @partial(jax.jit, static_argnames=["control_type"])
@@ -167,14 +163,19 @@ def _rescale_action(action: Array, control_type: str) -> Array:
             raise NotImplementedError(
                 f"Rescaling not implemented for control type '{control_type}'"
             )
-
         return action * params.scale_factor + params.mean

+    @staticmethod
+    @jax.jit
+    def _done(terminated: Array, truncated: Array) -> Array:
+        return jnp.logical_or(terminated, truncated)
+
     def reset_all(
         self, *, seed: int | None = None, options: dict | None = None
     ) -> tuple[dict[str, Array], dict]:
         super().reset(seed=seed)
-
+        if seed is not None:
+            self.jax_key = jax.random.key(seed)
         # Resets ALL (!) environments
         if options is None:
             options = {}
@@ -183,7 +184,7 @@

         self.prev_done = jnp.zeros((self.sim.n_worlds), dtype=jnp.bool_)

-        return self._get_obs(), {}
+        return self._obs(), {}

     def reset(self, mask: Array) -> None:
         self.sim.reset(mask=mask)
@@ -241,26 +242,21 @@ def _terminated(dones: Array, states: SimState, contacts: Array) -> Array:
         return jnp.where(dones, False, terminated)

     @staticmethod
-    @jax.jit
-    def _truncated(
-        dones: Array, steps: Array, max_episode_steps: Array, n_substeps: Array
-    ) -> Array:
+    @partial(jax.jit, static_argnames=["max_episode_steps", "n_substeps"])
+    def _truncated(dones: Array, steps: Array, max_episode_steps: int, n_substeps: int) -> Array:
         truncated = steps / n_substeps >= max_episode_steps
         return jnp.where(dones, False, truncated)

     def render(self):
         self.sim.render()

-    def _get_obs(self) -> dict[str, Array]:
-        obs = {
-            state: maybe_to_numpy(
-                getattr(self.sim.states, state)[..., 2]
-                if state == "pos"
-                else getattr(self.sim.states, state),
-                self.return_datatype == "numpy",
-            )
-            for state in self.states_to_include_in_obs
-        }
+    def _obs(self) -> dict[str, Array]:
+        convert = self.return_datatype == "numpy"
+        fields = self.states_to_include_in_obs
+        states = [maybe_to_numpy(getattr(self.sim.states, field), convert) for field in fields]
+        obs = {k: v for k, v in zip(fields, states)}
+        if "pos" in obs:
+            obs["pos"] = obs["pos"][..., 2]
         return obs


@@ -276,8 +272,7 @@ def __init__(self, **kwargs: dict):
             -jnp.inf, jnp.inf, shape=(self._obs_size,), dtype=jnp.float32
         )
         self.observation_space = batch_space(self.single_observation_space, self.sim.n_worlds)
-
-        self.goal = jnp.zeros((kwargs["n_worlds"], 3), dtype=jnp.float32)
+        self.goal = jnp.zeros((kwargs["n_worlds"], 3), dtype=jnp.float32, device=self.device)

     @property
     def reward(self) -> Array:
@@ -303,8 +298,8 @@ def reset(self, mask: Array) -> None:
         )
         self.goal = self.goal.at[mask].set(new_goals[mask])

-    def _get_obs(self) -> dict[str, Array]:
-        obs = super()._get_obs()
+    def _obs(self) -> dict[str, Array]:
+        obs = super()._obs()
         obs["difference_to_goal"] = [self.goal - self.sim.states.pos]
         return obs

@@ -348,7 +343,7 @@ def reset(self, mask: Array) -> None:
         )
         self.target_vel = self.target_vel.at[mask].set(new_target_vel[mask])

-    def _get_obs(self) -> dict[str, Array]:
-        obs = super()._get_obs()
+    def _obs(self) -> dict[str, Array]:
+        obs = super()._obs()
         obs["difference_to_target_vel"] = [self.target_vel - self.sim.states.vel]
         return obs
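
Note: the environment now keeps max_episode_steps and n_substeps as plain Python ints and passes them to jitted helpers via static_argnames, so JAX treats them as compile-time constants instead of traced arrays (a new value triggers a recompile; repeated calls with the same ints hit the cache). A small self-contained sketch mirroring the new _truncated helper, with an illustrative call added (truncated_fn and the sample values are made up for the example):

from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=["max_episode_steps", "n_substeps"])
def truncated_fn(dones: jax.Array, steps: jax.Array, max_episode_steps: int, n_substeps: int) -> jax.Array:
    # max_episode_steps and n_substeps are static: the comparison is specialized at compile time.
    truncated = steps / n_substeps >= max_episode_steps
    return jnp.where(dones, False, truncated)


dones = jnp.zeros((4,), dtype=jnp.bool_)
steps = jnp.array([0, 500, 1000, 1500])
print(truncated_fn(dones, steps, 200, 5))  # [False False  True  True]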

examples/gymnasium_env.py

+3 −6

@@ -18,19 +18,16 @@
 SEED = 42

 envs = gymnasium.make_vec(
-    "CrazyflowEnvReachGoal-v0",
+    "DroneReachPos-v0",
     max_episode_steps=1000,
     return_datatype="numpy",
     num_envs=sim_config.n_worlds,
-    jax_random_key=SEED,
     **sim_config,
 )

 # action for going up (in attitude control). NOTE actions are rescaled in the environment
-action = np.array(
-    [[[-0.2, 0, 0, 0] for _ in range(sim_config.n_drones)] for _ in range(sim_config.n_worlds)],
-    dtype=np.float32,
-).reshape(sim_config.n_worlds, -1)
+action = np.zeros((sim_config.n_worlds * sim_config.n_drones, 4), dtype=np.float32)
+action[..., 0] = -0.2

 obs, info = envs.reset_all(seed=SEED)
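
Note: the diff hunk for the example ends at the reset call. For orientation, a brief illustrative continuation (not part of this commit; envs.step and envs.close do appear in the benchmark files above) would step the vectorized environment with the prepared action:

# Illustrative continuation, not part of the commit.
for _ in range(500):
    obs, reward, terminated, truncated, info = envs.step(action)

envs.close()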
