Skip to content

Commit 1bb1b94

Browse files
committed
Merge branch 'main' into oliver_symbolic
2 parents 9dc1ea7 + 1a65f13 commit 1bb1b94

File tree

7 files changed

+90
-63
lines changed

7 files changed

+90
-63
lines changed
File renamed without changes.

benchmark/main.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def analyze_timings(times: list[float], n_steps: int, n_worlds: int, freq: float
2121
# Check for significant variance
2222
if tmax / tmin > 5:
2323
print("Warning: step time varies by more than 5x. Is JIT compiling during the benchmark?")
24-
print(f"Times: max {tmax:.2e}@{idx_tmax}, min {tmin:.2e}@{idx_tmin}")
24+
print(f"Times: max {tmax:.2e} @ {idx_tmax}, min {tmin:.2e} @ {idx_tmin}")
2525

2626
# Performance metrics
2727
n_frames = n_steps * n_worlds # Number of frames simulated
@@ -43,7 +43,7 @@ def profile_gym_env_step(sim_config: config_dict.ConfigDict, n_steps: int, devic
4343
device = jax.devices(device)[0]
4444

4545
envs = gymnasium.make_vec(
46-
"CrazyflowEnvReachGoal-v0",
46+
"DroneReachPos-v0",
4747
max_episode_steps=200,
4848
return_datatype="numpy",
4949
num_envs=sim_config.n_worlds,
@@ -114,10 +114,10 @@ def main():
114114
sim_config.controller = "emulatefirmware"
115115
sim_config.device = device
116116

117-
print("SIM PERFORMANCE")
117+
print("Simulator performance")
118118
profile_step(sim_config, 100, device)
119119

120-
print("\nGYM ENV PERFORMANCE")
120+
print("\nGymnasium environment performance")
121121
profile_gym_env_step(sim_config, 100, device)
122122

123123

benchmark/performance.py

+22-29
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
15
import gymnasium
26
import jax
37
import numpy as np
@@ -8,6 +12,9 @@
812
import crazyflow # noqa: F401, ensure gymnasium envs are registered
913
from crazyflow.sim.core import Sim
1014

15+
if TYPE_CHECKING:
16+
from crazyflow.gymnasium_envs import CrazyflowEnvReachGoal
17+
1118

1219
def profile_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str):
1320
sim = Sim(**sim_config)
@@ -40,8 +47,8 @@ def profile_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str):
4047
def profile_gym_env_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str):
4148
device = jax.devices(device)[0]
4249

43-
envs = gymnasium.make_vec(
44-
"CrazyflowEnvReachGoal-v0",
50+
envs: CrazyflowEnvReachGoal = gymnasium.make_vec(
51+
"DroneReachPos-v0",
4552
max_episode_steps=200,
4653
return_datatype="numpy",
4754
num_envs=sim_config.n_worlds,
@@ -50,28 +57,25 @@ def profile_gym_env_step(sim_config: config_dict.ConfigDict, n_steps: int, devic
5057
)
5158

5259
# Action for going up (in attitude control)
53-
action = np.array(
54-
[[[-0.3, 0, 0, 0] for _ in range(sim_config.n_drones)] for _ in range(sim_config.n_worlds)],
55-
dtype=np.float32,
56-
).reshape(sim_config.n_worlds, -1)
57-
58-
# step through env once to ensure JIT compilation
59-
_, _ = envs.reset_all(seed=42)
60-
_, _, _, _, _ = envs.step(action)
61-
_, _ = envs.reset_all(seed=42)
62-
_, _, _, _, _ = envs.step(action)
63-
_, _ = envs.reset_all(seed=42)
64-
_, _, _, _, _ = envs.step(action)
65-
_, _ = envs.reset_all(seed=42)
66-
67-
jax.block_until_ready(envs.unwrapped.sim._mjx_data) # Ensure JIT compiled dynamics
60+
action = np.zeros((sim_config.n_worlds, 4), dtype=np.float32)
61+
action[..., 0] = -0.3
62+
63+
# Step through env once to ensure JIT compilation.
64+
# TODO: Currently triggering recompiles also after the first full run. Investigate why and fix
65+
# envs accordingly.
66+
envs.reset_all(seed=42)
67+
68+
for _ in range(envs.max_episode_steps + 1): # Ensure all paths have been taken at least once
69+
envs.step(action)
70+
71+
envs.reset_all(seed=42)
6872

6973
profiler = Profiler()
7074
profiler.start()
7175

7276
for _ in range(n_steps):
7377
_, _, _, _, _ = envs.step(action)
74-
jax.block_until_ready(envs.unwrapped.sim._mjx_data)
78+
7579
profiler.stop()
7680
renderer = HTMLRenderer()
7781
renderer.open_in_browser(profiler.last_session)
@@ -89,17 +93,6 @@ def main():
8993
sim_config.device = device
9094

9195
profile_step(sim_config, 1000, device)
92-
# old | new
93-
# sys_id + attitude:
94-
# 0.61 reset, 0.61 step | 0.61 reset, 0.61 step
95-
# sys_id + state:
96-
# 14.53 step, 0.53 reset | 0.75 reset, 0.88 step
97-
98-
# Analytical + attitude:
99-
# 0.75 reset, 9.38 step | 0.75 reset, 0.80 step
100-
# Analytical + state:
101-
# 0.75 reset, 15.1 step | 0.75 reset, 0.82 step
102-
10396
profile_gym_env_step(sim_config, 1000, device)
10497

10598

crazyflow/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
import crazyflow.gymnasium_envs # noqa: F401, ensure gymnasium envs are registered

crazyflow/gymnasium_envs/__init__.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
from gymnasium.envs.registration import register
22

3+
from crazyflow.gymnasium_envs.crazyflow import CrazyflowEnvReachGoal, CrazyflowEnvTargetVelocity
4+
5+
__all__ = ["CrazyflowEnvReachGoal", "CrazyflowEnvTargetVelocity"]
6+
37
register(
4-
id="CrazyflowEnvReachGoal-v0",
5-
vector_entry_point="crazyflow.gymnasium_envs:CrazyflowEnvReachGoal",
8+
id="DroneReachPos-v0",
9+
vector_entry_point="crazyflow.gymnasium_envs.crazyflow:CrazyflowEnvReachGoal",
610
)
711

812
register(
9-
id="CrazyflowEnvTargetVelocity-v0",
10-
vector_entry_point="crazyflow.gymnasium_envs:CrazyflowEnvTargetVelocity",
13+
id="DroneReachVel-v0",
14+
vector_entry_point="crazyflow.gymnasium_envs.crazyflow:CrazyflowEnvTargetVelocity",
1115
)

crazyflow/gymnasium_envs/crazyflow.py

+27-26
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
import math
22
import warnings
33
from functools import partial
4-
from typing import Dict, Literal, Optional, Tuple
4+
from typing import Literal
55

66
import jax
77
import jax.numpy as jnp
8-
import numpy as np
98
from flax.struct import dataclass
109
from gymnasium import spaces
1110
from gymnasium.vector import VectorEnv
1211
from gymnasium.vector.utils import batch_space
1312
from jax import Array
13+
from numpy.typing import NDArray
1414

1515
from crazyflow.control.controller import MAX_THRUST, MIN_THRUST, Control
1616
from crazyflow.sim.core import Sim
@@ -19,8 +19,8 @@
1919

2020
@dataclass
2121
class RescaleParams:
22-
scale_factor: jnp.ndarray
23-
mean: jnp.ndarray
22+
scale_factor: Array
23+
mean: Array
2424

2525

2626
CONTROL_RESCALE_PARAMS = {
@@ -35,6 +35,12 @@ class RescaleParams:
3535
}
3636

3737

38+
@partial(jax.jit, static_argnames=["convert"])
39+
def maybe_to_numpy(data: Array, convert: bool) -> NDArray | Array:
40+
"""Converts data to numpy array if convert is True."""
41+
return jax.lax.cond(convert, lambda: jax.device_get(data), lambda: data)
42+
43+
3844
class CrazyflowBaseEnv(VectorEnv):
3945
"""JAX Gymnasium environment for Crazyflie simulation."""
4046

@@ -103,7 +109,7 @@ def __init__(
103109
)
104110
self.observation_space = batch_space(self.single_observation_space, self.sim.n_worlds)
105111

106-
def step(self, action: Array) -> Tuple[Array, Array, Array, Array, Dict]:
112+
def step(self, action: Array) -> tuple[Array, Array, Array, Array, dict]:
107113
assert self.action_space.contains(action), f"{action!r} ({type(action)}) invalid"
108114
action = jnp.array(action, device=self.device).reshape(
109115
(self.sim.n_worlds, self.sim.n_drones, -1)
@@ -139,8 +145,8 @@ def step(self, action: Array) -> Tuple[Array, Array, Array, Array, Dict]:
139145
return (
140146
self._get_obs(),
141147
reward,
142-
self._maybe_to_numpy(terminated),
143-
self._maybe_to_numpy(truncated),
148+
maybe_to_numpy(terminated, self.return_datatype == "numpy"),
149+
maybe_to_numpy(truncated, self.return_datatype == "numpy"),
144150
{},
145151
)
146152

@@ -165,7 +171,7 @@ def _rescale_action(action: Array, control_type: str) -> Array:
165171
return action * params.scale_factor + params.mean
166172

167173
def reset_all(
168-
self, *, seed: Optional[int] = None, options: Optional[dict] = None
174+
self, *, seed: int | None = None, options: dict | None = None
169175
) -> tuple[dict[str, Array], dict]:
170176
super().reset(seed=seed)
171177

@@ -226,42 +232,37 @@ def _reward() -> None:
226232

227233
@staticmethod
228234
@jax.jit
229-
def _terminated(dones: jax.Array, states: SimState, contacts: jax.Array) -> jnp.ndarray:
235+
def _terminated(dones: Array, states: SimState, contacts: Array) -> Array:
230236
contact = jnp.any(contacts, axis=1)
231237
z_coords = states.pos[..., 2]
232-
below_ground = jnp.any(
233-
z_coords < -0.1, axis=1
234-
) # Should not be triggered due to collision checking
238+
# Sanity check if we are below the ground. Should not be triggered due to collision checking
239+
below_ground = jnp.any(z_coords < -0.1, axis=1)
235240
terminated = jnp.logical_or(below_ground, contact) # no termination condition
236241
return jnp.where(dones, False, terminated)
237242

238243
@staticmethod
239244
@jax.jit
240245
def _truncated(
241-
dones: jax.Array, steps: jax.Array, max_episode_steps: jax.Array, n_substeps: jax.Array
242-
) -> jnp.ndarray:
246+
dones: Array, steps: Array, max_episode_steps: Array, n_substeps: Array
247+
) -> Array:
243248
truncated = steps / n_substeps >= max_episode_steps
244249
return jnp.where(dones, False, truncated)
245250

246251
def render(self):
247252
self.sim.render()
248253

249-
def _get_obs(self) -> Dict[str, jnp.ndarray]:
254+
def _get_obs(self) -> dict[str, Array]:
250255
obs = {
251-
state: self._maybe_to_numpy(
256+
state: maybe_to_numpy(
252257
getattr(self.sim.states, state)[..., 2]
253258
if state == "pos"
254-
else getattr(self.sim.states, state)
259+
else getattr(self.sim.states, state),
260+
self.return_datatype == "numpy",
255261
)
256262
for state in self.states_to_include_in_obs
257263
}
258264
return obs
259265

260-
def _maybe_to_numpy(self, data: Array) -> np.ndarray:
261-
if self.return_datatype == "numpy" and not isinstance(data, np.ndarray):
262-
return jax.device_get(data)
263-
return data
264-
265266

266267
class CrazyflowEnvReachGoal(CrazyflowBaseEnv):
267268
"""JAX Gymnasium environment for Crazyflie simulation."""
@@ -284,7 +285,7 @@ def reward(self) -> Array:
284285

285286
@staticmethod
286287
@jax.jit
287-
def _reward(terminated: jax.Array, states: SimState, goal: jax.Array) -> jnp.ndarray:
288+
def _reward(terminated: Array, states: SimState, goal: Array) -> Array:
288289
norm_distance = jnp.linalg.norm(states.pos - goal, axis=2)
289290
reward = jnp.exp(-2.0 * norm_distance)
290291
return jnp.where(terminated, -1.0, reward)
@@ -302,7 +303,7 @@ def reset(self, mask: Array) -> None:
302303
)
303304
self.goal = self.goal.at[mask].set(new_goals[mask])
304305

305-
def _get_obs(self) -> Dict[str, jnp.ndarray]:
306+
def _get_obs(self) -> dict[str, Array]:
306307
obs = super()._get_obs()
307308
obs["difference_to_goal"] = [self.goal - self.sim.states.pos]
308309
return obs
@@ -329,7 +330,7 @@ def reward(self) -> Array:
329330

330331
@staticmethod
331332
@jax.jit
332-
def _reward(terminated: jax.Array, states: SimState, target_vel: jax.Array) -> jnp.ndarray:
333+
def _reward(terminated: Array, states: SimState, target_vel: Array) -> Array:
333334
norm_distance = jnp.linalg.norm(states.vel - target_vel, axis=2)
334335
reward = jnp.exp(-norm_distance)
335336
return jnp.where(terminated, -1.0, reward)
@@ -347,7 +348,7 @@ def reset(self, mask: Array) -> None:
347348
)
348349
self.target_vel = self.target_vel.at[mask].set(new_target_vel[mask])
349350

350-
def _get_obs(self) -> Dict[str, jnp.ndarray]:
351+
def _get_obs(self) -> dict[str, Array]:
351352
obs = super()._get_obs()
352353
obs["difference_to_target_vel"] = [self.target_vel - self.sim.states.vel]
353354
return obs

0 commit comments

Comments
 (0)