Skip to content

Commit 6a7b856

Browse files
committed
Clean up formatting
1 parent 97bde18 commit 6a7b856

File tree

8 files changed

+63
-73
lines changed

8 files changed

+63
-73
lines changed

benchmark/main.py

+10-9
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,36 @@
11
import time
2-
from ml_collections import config_dict
32

3+
import gymnasium
44
import jax
55
import jax.numpy as jnp
66
import numpy as np
7-
import gymnasium
7+
from ml_collections import config_dict
8+
9+
import crazyflow # noqa: F401, ensure gymnasium envs are registered
810
from crazyflow.sim.core import Sim
9-
import crazyflow.gymnasium_envs
1011

1112

1213
def analyze_timings(times: list[float], n_steps: int, n_worlds: int, freq: float) -> None:
1314
"""Analyze timing results and print performance metrics."""
1415
if not times:
1516
raise ValueError("The list of timing results is empty.")
16-
17+
1718
tmin, idx_tmin = np.min(times), np.argmin(times)
1819
tmax, idx_tmax = np.max(times), np.argmax(times)
19-
20+
2021
# Check for significant variance
2122
if tmax / tmin > 5:
2223
print("Warning: step time varies by more than 5x. Is JIT compiling during the benchmark?")
2324
print(f"Times: max {tmax:.2e}@{idx_tmax}, min {tmin:.2e}@{idx_tmin}")
24-
25+
2526
# Performance metrics
2627
n_frames = n_steps * n_worlds # Number of frames simulated
2728
total_time = np.sum(times)
2829
avg_step_time = np.mean(times)
2930
step_time_std = np.std(times)
3031
fps = n_frames / total_time
3132
real_time_factor = (n_steps / freq) * n_worlds / total_time
32-
33+
3334
print(
3435
f"Avg step time: {avg_step_time:.2e}s, std: {step_time_std:.2e}"
3536
f"\nFPS: {fps:.3e}, Real time factor: {real_time_factor:.2e}"
@@ -62,7 +63,7 @@ def profile_gym_env_step(sim_config: config_dict.ConfigDict, n_steps: int, devic
6263
_, _ = envs.reset_all(seed=42)
6364
_, _, _, _, _ = envs.step(action)
6465
_, _ = envs.reset_all(seed=42)
65-
66+
6667
jax.block_until_ready(envs.unwrapped.sim._mjx_data) # Ensure JIT compiled dynamics
6768

6869
# Step through the environment
@@ -85,7 +86,7 @@ def profile_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str):
8586

8687
cmd = jnp.zeros((sim.n_worlds, sim.n_drones, 4), device=device)
8788
cmd = cmd.at[0, 0, 0].set(1)
88-
89+
8990
sim.reset()
9091
sim.attitude_control(cmd)
9192
sim.step()

benchmark/performance.py

+6-8
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,16 @@
1+
import gymnasium
12
import jax
2-
import jax.numpy as jnp
33
import numpy as np
4-
import gymnasium
54
from ml_collections import config_dict
65
from pyinstrument import Profiler
76
from pyinstrument.renderers.html import HTMLRenderer
87

8+
import crazyflow # noqa: F401, ensure gymnasium envs are registered
99
from crazyflow.sim.core import Sim
10-
import crazyflow.gymnasium_envs
1110

1211

1312
def profile_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str):
14-
sim = Sim(
15-
**sim_config
16-
)
13+
sim = Sim(**sim_config)
1714
device = jax.devices(device)[0]
1815
ndim = 13 if sim.control == "state" else 4
1916
control_fn = sim.state_control if sim.control == "state" else sim.attitude_control
@@ -39,6 +36,7 @@ def profile_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str):
3936
renderer = HTMLRenderer()
4037
renderer.open_in_browser(profiler.last_session)
4138

39+
4240
def profile_gym_env_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str):
4341
device = jax.devices(device)[0]
4442

@@ -65,7 +63,7 @@ def profile_gym_env_step(sim_config: config_dict.ConfigDict, n_steps: int, devic
6563
_, _ = envs.reset_all(seed=42)
6664
_, _, _, _, _ = envs.step(action)
6765
_, _ = envs.reset_all(seed=42)
68-
66+
6967
jax.block_until_ready(envs.unwrapped.sim._mjx_data) # Ensure JIT compiled dynamics
7068

7169
profiler = Profiler()
@@ -79,6 +77,7 @@ def profile_gym_env_step(sim_config: config_dict.ConfigDict, n_steps: int, devic
7977
renderer.open_in_browser(profiler.last_session)
8078
envs.close()
8179

80+
8281
def main():
8382
device = "cpu"
8483
sim_config = config_dict.ConfigDict()
@@ -104,6 +103,5 @@ def main():
104103
profile_gym_env_step(sim_config, 1000, device)
105104

106105

107-
108106
if __name__ == "__main__":
109107
main()

crazyflow/gymnasium_envs/__init__.py

-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
from crazyflow.gymnasium_envs.crazyflow import CrazyflowEnvReachGoal, CrazyflowEnvTargetVelocity
2-
3-
41
from gymnasium.envs.registration import register
52

63
register(

crazyflow/gymnasium_envs/crazyflow.py

+31-30
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,38 @@
11
import math
22
import warnings
3-
from dataclasses import fields
43
from functools import partial
54
from typing import Dict, Literal, Optional, Tuple
65

76
import jax
87
import jax.numpy as jnp
98
import numpy as np
9+
from flax.struct import dataclass
1010
from gymnasium import spaces
1111
from gymnasium.vector import VectorEnv
1212
from gymnasium.vector.utils import batch_space
1313
from jax import Array
1414

15-
from crazyflow.control.controller import Control, MAX_THRUST, MIN_THRUST
15+
from crazyflow.control.controller import MAX_THRUST, MIN_THRUST, Control
1616
from crazyflow.sim.core import Sim
1717
from crazyflow.sim.structs import SimState
18-
from flax.struct import dataclass
18+
1919

2020
@dataclass
2121
class RescaleParams:
2222
scale_factor: jnp.ndarray
2323
mean: jnp.ndarray
2424

25+
2526
CONTROL_RESCALE_PARAMS = {
26-
"state": None,
27-
"thrust": None,
28-
"attitude": RescaleParams(
29-
scale_factor=jnp.array([
30-
4 * (MAX_THRUST - MIN_THRUST) / 2,
31-
jnp.pi / 6,
32-
jnp.pi / 6,
33-
jnp.pi / 6
34-
]),
35-
mean=jnp.array([
36-
4 * (MIN_THRUST + MAX_THRUST) / 2,
37-
0.0,
38-
0.0,
39-
0.0
40-
])
27+
"state": None,
28+
"thrust": None,
29+
"attitude": RescaleParams(
30+
scale_factor=jnp.array(
31+
[4 * (MAX_THRUST - MIN_THRUST) / 2, jnp.pi / 6, jnp.pi / 6, jnp.pi / 6]
4132
),
42-
}
33+
mean=jnp.array([4 * (MIN_THRUST + MAX_THRUST) / 2, 0.0, 0.0, 0.0]),
34+
),
35+
}
4336

4437

4538
class CrazyflowBaseEnv(VectorEnv):
@@ -48,7 +41,7 @@ class CrazyflowBaseEnv(VectorEnv):
4841
def __init__(
4942
self,
5043
*,
51-
jax_random_key, # required for jax random number generator
44+
jax_random_key: int, # required for jax random number generator
5245
num_envs: int = 1, # required for VectorEnv
5346
max_episode_steps: int = 1000,
5447
return_datatype: Literal["numpy", "jax"] = "jax",
@@ -57,9 +50,13 @@ def __init__(
5750
"""Summary: Initializes the CrazyflowEnv.
5851
5952
Args:
60-
max_episode_steps (int): The maximum number of steps per episode.
61-
return_datatype (Literal["numpy", "jax"]): The data type for returned arrays, either "numpy" or "jax". If specified as "numpy", the returned arrays will be numpy arrays on the CPU. If specified as "jax", the returned arrays will be jax arrays on the "device" specifiedf or the simulation.
62-
**kwargs: Takes arguments that are passed to the Crazyfly simulation .
53+
jax_random_key: The random key for the jax random number generator.
54+
num_envs: The number of environments to run in parallel.
55+
max_episode_steps: The maximum number of steps per episode.
56+
return_datatype: The data type for returned arrays, either "numpy" or "jax". If "numpy",
57+
the returned arrays will be numpy arrays on the CPU. If "jax", the returned arrays
58+
will be jax arrays on the "device" specified for the simulation.
59+
**kwargs: Takes arguments that are passed to the Crazyfly simulation.
6360
"""
6461
assert num_envs == kwargs["n_worlds"], "num_envs must be equal to n_worlds"
6562

@@ -167,7 +164,9 @@ def _rescale_action(action: Array, control_type: str) -> Array:
167164

168165
return action * params.scale_factor + params.mean
169166

170-
def reset_all(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
167+
def reset_all(
168+
self, *, seed: Optional[int] = None, options: Optional[dict] = None
169+
) -> tuple[dict[str, Array], dict]:
171170
super().reset(seed=seed)
172171

173172
# Resets ALL (!) environments
@@ -209,15 +208,15 @@ def reset(self, mask: Array) -> None:
209208
)
210209

211210
@property
212-
def reward(self):
211+
def reward(self) -> Array:
213212
return self._reward(self.terminated, self.sim.states)
214213

215214
@property
216-
def terminated(self):
215+
def terminated(self) -> Array:
217216
return self._terminated(self.prev_done, self.sim.states, self.sim.contacts())
218217

219218
@property
220-
def truncated(self):
219+
def truncated(self) -> Array:
221220
return self._truncated(
222221
self.prev_done, self.sim.steps, self.max_episode_steps, self.n_substeps
223222
)
@@ -250,7 +249,9 @@ def render(self):
250249
def _get_obs(self) -> Dict[str, jnp.ndarray]:
251250
obs = {
252251
state: self._maybe_to_numpy(
253-
getattr(self.sim.states, state)[..., 2] if state == "pos" else getattr(self.sim.states, state)
252+
getattr(self.sim.states, state)[..., 2]
253+
if state == "pos"
254+
else getattr(self.sim.states, state)
254255
)
255256
for state in self.states_to_include_in_obs
256257
}
@@ -278,7 +279,7 @@ def __init__(self, **kwargs: dict):
278279
self.goal = jnp.zeros((kwargs["n_worlds"], 3), dtype=jnp.float32)
279280

280281
@property
281-
def reward(self):
282+
def reward(self) -> Array:
282283
return self._reward(self.terminated, self.sim.states, self.goal)
283284

284285
@staticmethod
@@ -323,7 +324,7 @@ def __init__(self, **kwargs: dict):
323324
self.target_vel = jnp.zeros((kwargs["n_worlds"], 3), dtype=jnp.float32)
324325

325326
@property
326-
def reward(self):
327+
def reward(self) -> Array:
327328
return self._reward(self.terminated, self.sim.states, self.target_vel)
328329

329330
@staticmethod

examples/gradient.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
def main():
1212
sim = Sim(control=Control.state)
1313

14-
def step(cmd: NDArray):
14+
def step(cmd: NDArray) -> jax.Array:
1515
sim.reset()
1616
sim.state_control(cmd)
1717
sim.step()

examples/gymnasium_env.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import numpy as np
33
from ml_collections import config_dict
44

5-
import crazyflow.gymnasium_envs
65
from crazyflow.control.controller import Control, Controller
76
from crazyflow.sim.physics import Physics
87

@@ -13,23 +12,24 @@
1312
sim_config.control = Control.default
1413
sim_config.controller = Controller.default
1514
sim_config.control_freq = 50
16-
sim_config.n_drones=1
17-
sim_config.n_worlds=20
15+
sim_config.n_drones = 1
16+
sim_config.n_worlds = 20
1817

19-
SEED=42
18+
SEED = 42
2019

2120
envs = gymnasium.make_vec(
2221
"CrazyflowEnvReachGoal-v0",
2322
max_episode_steps=1000,
2423
return_datatype="numpy",
25-
num_envs=sim_config.n_worlds,
24+
num_envs=sim_config.n_worlds,
2625
jax_random_key=SEED,
2726
**sim_config,
2827
)
2928

3029
# action for going up (in attitude control). NOTE actions are rescaled in the environment
3130
action = np.array(
32-
[[[-0.2, 0, 0, 0] for _ in range(sim_config.n_drones)] for _ in range(sim_config.n_worlds)], dtype=np.float32
31+
[[[-0.2, 0, 0, 0] for _ in range(sim_config.n_drones)] for _ in range(sim_config.n_worlds)],
32+
dtype=np.float32,
3333
).reshape(sim_config.n_worlds, -1)
3434

3535
obs, info = envs.reset_all(seed=SEED)
@@ -38,5 +38,5 @@
3838
for _ in range(1500):
3939
observation, reward, terminated, truncated, info = envs.step(action)
4040
envs.render()
41-
41+
4242
envs.close()

examples/spiral.py

+2-12
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,11 @@
11
import numpy as np
22

3-
from crazyflow.control.controller import Control, Controller
3+
from crazyflow.control.controller import Control
44
from crazyflow.sim.core import Sim
5-
from crazyflow.sim.physics import Physics
65

76

87
def main():
9-
sim = Sim(
10-
n_worlds=1,
11-
n_drones=4,
12-
physics=Physics.analytical,
13-
control=Control.state,
14-
controller=Controller.emulatefirmware,
15-
freq=500,
16-
control_freq=500,
17-
device="cpu",
18-
)
8+
sim = Sim(n_worlds=1, n_drones=4, control=Control.state, device="cpu")
199
sim.reset()
2010
duration = 5.0
2111
fps = 60

pyproject.toml

+6-3
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ dependencies = [
2929
]
3030

3131
[project.optional-dependencies]
32-
test = ["pytest>=8.0.0"]
32+
test = ["pytest>=8.0.0", "pytest-cov"]
3333
gpu = ["jax[cuda12]"]
3434

3535
[tool.setuptools.packages]
@@ -69,13 +69,16 @@ target-version = "py38"
6969

7070
[tool.ruff.lint]
7171
select = ["E4", "E7", "E9", "F", "I", "D", "TCH", "ANN"]
72-
ignore = ["ANN101", "ANN401"]
72+
ignore = ["ANN401"]
7373
fixable = ["ALL"]
7474
unfixable = []
7575

7676
[tool.ruff.lint.per-file-ignores]
77-
"benchmarks/*" = ["D100", "D103"]
77+
"benchmark/*" = ["D100", "D103"]
7878
"tests/*" = ["D100", "D103", "D104"]
79+
"examples/*" = ["D100", "D103"]
80+
# TODO: Remove once everything is stable and document
81+
"crazyflow/*" = ["D100", "D101", "D102", "D104", "D107"]
7982

8083

8184
[tool.ruff.lint.pydocstyle]

0 commit comments

Comments
 (0)