Skip to content

Commit 3359bc4

Browse files
author
Lui
committed
Merge branch 'attitude_interface' into notebooks
merge attitude interface into notebooks
2 parents 2cb8c0a + 289dfc3 commit 3359bc4

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+2417
-2770
lines changed

benchmark/main.py

+15-20
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from ml_collections import config_dict
88

99
import crazyflow # noqa: F401, ensure gymnasium envs are registered
10-
from crazyflow.sim.core import Sim
10+
from crazyflow.sim import Sim
1111

1212

1313
def analyze_timings(times: list[float], n_steps: int, n_worlds: int, freq: float) -> None:
@@ -19,8 +19,8 @@ def analyze_timings(times: list[float], n_steps: int, n_worlds: int, freq: float
1919
tmax, idx_tmax = np.max(times), np.argmax(times)
2020

2121
# Check for significant variance
22-
if tmax / tmin > 5:
23-
print("Warning: step time varies by more than 5x. Is JIT compiling during the benchmark?")
22+
if tmax / tmin > 10:
23+
print("Warning: step time varies by more than 10x. Is JIT compiling during the benchmark?")
2424
print(f"Times: max {tmax:.2e} @ {idx_tmax}, min {tmin:.2e} @ {idx_tmin}")
2525

2626
# Performance metrics
@@ -43,28 +43,23 @@ def profile_gym_env_step(sim_config: config_dict.ConfigDict, n_steps: int, devic
4343
device = jax.devices(device)[0]
4444

4545
envs = gymnasium.make_vec(
46-
"DroneReachPos-v0",
47-
time_horizon_in_seconds=2,
48-
return_datatype="numpy",
49-
num_envs=sim_config.n_worlds,
50-
**sim_config,
46+
"DroneReachPos-v0", time_horizon_in_seconds=3, num_envs=sim_config.n_worlds, **sim_config
5147
)
5248

5349
# Action for going up (in attitude control)
5450
action = np.zeros((sim_config.n_worlds, 4), dtype=np.float32)
5551
action[..., 0] = 0.3
5652
# Step through env once to ensure JIT compilation
57-
envs.reset_all(seed=42)
58-
envs.step(action)
53+
envs.reset(seed=42)
5954
envs.step(action)
6055

61-
jax.block_until_ready(envs.unwrapped.sim.states.pos) # Ensure JIT compiled dynamics
56+
jax.block_until_ready(envs.unwrapped.sim.data) # Ensure JIT compiled dynamics
6257

6358
# Step through the environment
6459
for _ in range(n_steps):
6560
tstart = time.perf_counter()
6661
envs.step(action)
67-
jax.block_until_ready(envs.unwrapped.sim.states.pos)
62+
jax.block_until_ready(envs.unwrapped.sim.data)
6863
times.append(time.perf_counter() - tstart)
6964

7065
envs.close()
@@ -83,14 +78,14 @@ def profile_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str):
8378

8479
sim.reset()
8580
sim.attitude_control(cmd)
86-
sim.step()
87-
jax.block_until_ready(sim.states.pos) # Ensure JIT compiled dynamics
81+
sim.step(sim.freq // sim.control_freq)
82+
jax.block_until_ready(sim.data) # Ensure JIT compiled dynamics
8883

8984
for _ in range(n_steps):
9085
tstart = time.perf_counter()
9186
sim.attitude_control(cmd)
92-
sim.step()
93-
jax.block_until_ready(sim.states.pos)
87+
sim.step(sim.freq // sim.control_freq)
88+
jax.block_until_ready(sim.data)
9489
times.append(time.perf_counter() - tstart)
9590

9691
analyze_timings(times, n_steps, sim.n_worlds, sim.freq)
@@ -102,16 +97,16 @@ def main():
10297
sim_config = config_dict.ConfigDict()
10398
sim_config.n_worlds = 1
10499
sim_config.n_drones = 1
105-
sim_config.physics = "sys_id"
100+
sim_config.physics = "analytical"
106101
sim_config.control = "attitude"
107-
sim_config.controller = "emulatefirmware"
102+
sim_config.attitude_freq = 500
108103
sim_config.device = device
109104

110105
print("Simulator performance")
111-
profile_step(sim_config, 100, device)
106+
profile_step(sim_config, 1000, device)
112107

113108
print("\nGymnasium environment performance")
114-
profile_gym_env_step(sim_config, 100, device)
109+
profile_gym_env_step(sim_config, 1000, device)
115110

116111

117112
if __name__ == "__main__":

benchmark/performance.py

+9-15
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from pyinstrument.renderers.html import HTMLRenderer
1111

1212
import crazyflow # noqa: F401, ensure gymnasium envs are registered
13-
from crazyflow.sim.core import Sim
13+
from crazyflow.sim import Sim
1414

1515
if TYPE_CHECKING:
1616
from crazyflow.gymnasium_envs import CrazyflowEnvReachGoal
@@ -26,9 +26,7 @@ def profile_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str):
2626
sim.reset()
2727
control_fn(cmd)
2828
sim.step()
29-
sim.step()
30-
sim.reset()
31-
jax.block_until_ready(sim.states.pos)
29+
jax.block_until_ready(sim.data)
3230

3331
profiler = Profiler()
3432
profiler.start()
@@ -37,7 +35,7 @@ def profile_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str):
3735
control_fn(cmd)
3836
# sim.reset()
3937
sim.step()
40-
jax.block_until_ready(sim.states.pos)
38+
jax.block_until_ready(sim.data)
4139
profiler.stop()
4240
renderer = HTMLRenderer()
4341
renderer.open_in_browser(profiler.last_session)
@@ -47,29 +45,26 @@ def profile_gym_env_step(sim_config: config_dict.ConfigDict, n_steps: int, devic
4745
device = jax.devices(device)[0]
4846

4947
envs: CrazyflowEnvReachGoal = gymnasium.make_vec(
50-
"DroneReachPos-v0",
51-
time_horizon_in_seconds=2,
52-
return_datatype="numpy",
53-
num_envs=sim_config.n_worlds,
54-
**sim_config,
48+
"DroneReachPos-v0", time_horizon_in_seconds=2, num_envs=sim_config.n_worlds, **sim_config
5549
)
5650

5751
# Action for going up (in attitude control)
5852
action = np.zeros((sim_config.n_worlds, 4), dtype=np.float32)
59-
action[..., 0] = -0.3
53+
action[..., 0] = 0.3
6054

6155
# Step through env once to ensure JIT compilation.
62-
envs.reset_all(seed=42)
56+
envs.reset(seed=42)
6357
envs.step(action)
6458
envs.step(action) # Ensure all paths have been taken at least once
65-
envs.reset_all(seed=42)
59+
envs.reset(seed=42)
60+
jax.block_until_ready(envs.unwrapped.sim.data)
6661

6762
profiler = Profiler()
6863
profiler.start()
6964

7065
for _ in range(n_steps):
7166
envs.step(action)
72-
jax.block_until_ready(envs.unwrapped.sim.states.pos)
67+
jax.block_until_ready(envs.unwrapped.sim.data)
7368

7469
profiler.stop()
7570
renderer = HTMLRenderer()
@@ -84,7 +79,6 @@ def main():
8479
sim_config.n_drones = 1
8580
sim_config.physics = "analytical"
8681
sim_config.control = "attitude"
87-
sim_config.controller = "emulatefirmware"
8882
sim_config.device = device
8983

9084
profile_step(sim_config, 1000, device)

crazyflow/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -1 +1,5 @@
11
import crazyflow.gymnasium_envs # noqa: F401, ensure gymnasium envs are registered
2+
from crazyflow.control import Control
3+
from crazyflow.sim import Physics, Sim
4+
5+
__all__ = ["Sim", "Physics", "Control"]

crazyflow/constants.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
GRAVITY: float = 9.81
66

77
# Drone constants
8-
ARM_LEN: float = 0.46
8+
ARM_LEN: float = 0.0325 * jnp.sqrt(2)
99
MIX_MATRIX: Array = jnp.array([[-0.5, -0.5, -1], [-0.5, 0.5, 1], [0.5, 0.5, -1], [0.5, -0.5, 1]])
1010
SIGN_MIX_MATRIX: Array = jnp.sign(MIX_MATRIX)
1111
MASS: float = 0.027

crazyflow/control/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from crazyflow.control.control import Control
2+
3+
__all__ = ["Control"]

crazyflow/control/controller.py crazyflow/control/control.py

+27-14
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,28 @@ class Control(str, Enum):
2424
"""Control type of the simulated onboard controller."""
2525

2626
state = "state"
27+
"""State control takes [x, y, z, vx, vy, vz, ax, ay, az, yaw, roll_rate, pitch_rate, yaw_rate].
28+
29+
Note:
30+
Recommended frequency is >=20 Hz.
31+
32+
Warning:
33+
Currently, we only use positions, velocities, and yaw. The rest of the state is ignored.
34+
This is subject to change in the future.
35+
"""
2736
attitude = "attitude"
28-
thrust = "thrust"
29-
default = attitude
37+
"""Attitude control takes [collective thrust, roll, pitch, yaw].
3038
39+
Note:
40+
Recommended frequency is >=100 Hz.
41+
"""
42+
thrust = "thrust"
43+
"""Thrust control takes [thrust1, thrust2, thrust3, thrust4] for each drone motor.
3144
32-
class Controller(str, Enum):
33-
"""Controller type of the simulated onboard controller."""
34-
35-
pycffirmware = "pycffirmware"
36-
emulatefirmware = "emulatefirmware"
37-
default = emulatefirmware
45+
Note:
46+
Recommended frequency is >=500 Hz.
47+
"""
48+
default = attitude
3849

3950

4051
KF: float = 3.16e-10
@@ -90,27 +101,28 @@ def state2attitude(
90101

91102
@partial(jnp.vectorize, signature="(4),(4),(3),(3)->(4),(3)", excluded=[4])
92103
def attitude2rpm(
93-
cmd: Array, quat: Array, last_rpy: Array, rpy_err_i: Array, dt: float
104+
controls: Array, quat: Array, last_rpy: Array, rpy_err_i: Array, dt: float
94105
) -> tuple[Array, Array]:
95-
"""Convert the desired attitude and quaternion into motor RPMs."""
106+
"""Convert the desired collective thrust and attitude into motor RPMs."""
96107
rot = R.from_quat(quat)
97-
target_rot = R.from_euler("xyz", cmd[..., 1:])
108+
target_rot = R.from_euler("xyz", controls[1:])
98109
drot = (target_rot.inv() * rot).as_matrix()
99-
100110
# Extract the anti-symmetric part of the relative rotation matrix.
101111
rot_e = jnp.array([drot[2, 1] - drot[1, 2], drot[0, 2] - drot[2, 0], drot[1, 0] - drot[0, 1]])
102-
rpy_rates_e = -(rot.as_euler("xyz") - last_rpy) / dt # Assuming zero rpy_rates target
112+
# TODO: Assumes zero rpy_rates targets for now, use the actual target instead.
113+
rpy_rates_e = -(rot.as_euler("xyz") - last_rpy) / dt
103114
rpy_err_i = rpy_err_i - rot_e * dt
104115
rpy_err_i = jnp.clip(rpy_err_i, -1500.0, 1500.0)
105116
rpy_err_i = rpy_err_i.at[:2].set(jnp.clip(rpy_err_i[:2], -1.0, 1.0))
106117
# PID target torques.
107118
target_torques = -P_T * rot_e + D_T * rpy_rates_e + I_T * rpy_err_i
108119
target_torques = jnp.clip(target_torques, -3200, 3200)
109-
thrust_per_motor = cmd[0] / 4
120+
thrust_per_motor = jnp.atleast_1d(controls[0]) / 4
110121
pwm = jnp.clip(thrust2pwm(thrust_per_motor) + MIX_MATRIX @ target_torques, MIN_PWM, MAX_PWM)
111122
return pwm2rpm(pwm), rpy_err_i
112123

113124

125+
@partial(jnp.vectorize, signature="(4)->(4)")
114126
def thrust2pwm(thrust: Array) -> Array:
115127
"""Convert the desired thrust into motor PWM.
116128
@@ -124,6 +136,7 @@ def thrust2pwm(thrust: Array) -> Array:
124136
return jnp.clip((jnp.sqrt(thrust / KF) - PWM2RPM_CONST) / PWM2RPM_SCALE, MIN_PWM, MAX_PWM)
125137

126138

139+
@partial(jnp.vectorize, signature="(4)->(4)")
127140
def pwm2rpm(pwm: Array) -> Array:
128141
"""Convert the motors' PWMs into RPMs."""
129142
return PWM2RPM_CONST + PWM2RPM_SCALE * pwm

crazyflow/gymnasium_envs/__init__.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,20 @@
11
from gymnasium.envs.registration import register
22

33
from crazyflow.gymnasium_envs.crazyflow import (
4+
CrazyflowEnvFigureEightTrajectory,
45
CrazyflowEnvLanding,
56
CrazyflowEnvReachGoal,
67
CrazyflowEnvTargetVelocity,
7-
CrazyflowRL
8+
CrazyflowRL,
89
)
910

10-
__all__ = ["CrazyflowEnvReachGoal", "CrazyflowEnvTargetVelocity", "CrazyflowEnvLanding", "CrazyflowRL"]
11+
__all__ = [
12+
"CrazyflowEnvReachGoal",
13+
"CrazyflowEnvTargetVelocity",
14+
"CrazyflowEnvLanding",
15+
"CrazyflowRL",
16+
"CrazyflowEnvFigureEightTrajectory",
17+
]
1118

1219
register(
1320
id="DroneReachPos-v0",
@@ -19,8 +26,12 @@
1926
vector_entry_point="crazyflow.gymnasium_envs.crazyflow:CrazyflowEnvTargetVelocity",
2027
)
2128

22-
2329
register(
2430
id="DroneLanding-v0",
2531
vector_entry_point="crazyflow.gymnasium_envs.crazyflow:CrazyflowEnvLanding",
2632
)
33+
34+
register(
35+
id="DroneFigureEightTrajectory-v0",
36+
vector_entry_point="crazyflow.gymnasium_envs.crazyflow:CrazyflowEnvFigureEightTrajectory",
37+
)

0 commit comments

Comments
 (0)