Skip to content

Commit 2d0f4af

Browse files
committed
Merge branch 'dev'
2 parents 113e229 + a36d70b commit 2d0f4af

14 files changed

+503
-1407
lines changed

crazyflow/sim/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from crazyflow.sim.physics import Physics
22
from crazyflow.sim.sim import Sim
3-
from crazyflow.sim.symbolic import symbolic
3+
from crazyflow.sim.symbolic import symbolic_attitude, symbolic_from_sim, symbolic_thrust
44

5-
__all__ = ["Sim", "Physics", "symbolic"]
5+
__all__ = ["Sim", "Physics", "symbolic_attitude", "symbolic_from_sim", "symbolic_thrust"]

crazyflow/sim/physics.py

+17-23
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,10 @@
1010
from crazyflow.control.control import KF, KM
1111

1212
SYS_ID_PARAMS = {
13-
"acc_k1": 20.91,
14-
"acc_k2": 3.65,
15-
"roll_alpha": -3.96,
16-
"roll_beta": 4.08,
17-
"pitch_alpha": -6.00,
18-
"pitch_beta": 6.21,
19-
"yaw_alpha": 0.00,
20-
"yaw_beta": 0.00,
13+
"acc": jnp.array([20.907574256269616, 3.653687545690674]),
14+
"roll_acc": jnp.array([-130.3, -16.33, 119.3]),
15+
"pitch_acc": jnp.array([-99.94, -13.3, 84.73]),
16+
"yaw_acc": jnp.array([0.0, 0.0, 0.0]),
2117
}
2218

2319

@@ -60,27 +56,24 @@ def surrogate_identified_collective_wrench(
6056
rot = R.from_quat(quat)
6157
thrust = rot.apply(jnp.array([0, 0, collective_thrust]))
6258
drift = rot.apply(jnp.array([0, 0, 1]))
63-
prev_rpy_rates = rot.apply(rpy_rates, inverse=True)
64-
a1, a2 = SYS_ID_PARAMS["acc_k1"], SYS_ID_PARAMS["acc_k2"]
65-
acc = thrust * a1 + drift * a2
66-
# rpy_rates_deriv have no real meaning in this context, since the identified dynamics set the
67-
# rpy_rates to the commanded values directly. However, since we use a unified integration
68-
# interface for all physics models, we cannot access states directly. Instead, we calculate
69-
# which rpy_rates_deriv would have resulted in the desired rpy_rates, and return that.
70-
roll_cmd, pitch_cmd, yaw_cmd = attitude
59+
rpy_rates_local = rot.apply(rpy_rates, inverse=True)
60+
k1, k2 = SYS_ID_PARAMS["acc"]
61+
acc = thrust * k1 + drift * k2
7162
rpy = rot.as_euler("xyz")
72-
roll_rate = SYS_ID_PARAMS["roll_alpha"] * rpy[0] + SYS_ID_PARAMS["roll_beta"] * roll_cmd
73-
pitch_rate = SYS_ID_PARAMS["pitch_alpha"] * rpy[1] + SYS_ID_PARAMS["pitch_beta"] * pitch_cmd
74-
yaw_rate = SYS_ID_PARAMS["yaw_alpha"] * rpy[2] + SYS_ID_PARAMS["yaw_beta"] * yaw_cmd
75-
rpy_rates_local = jnp.array([roll_rate, pitch_rate, yaw_rate])
76-
rpy_rates_local_deriv = (rpy_rates_local - prev_rpy_rates) / dt
63+
k1, k2, k3 = SYS_ID_PARAMS["roll_acc"]
64+
roll_rate_deriv = k1 * rpy[0] + k2 * rpy_rates_local[0] + k3 * attitude[0]
65+
k1, k2, k3 = SYS_ID_PARAMS["pitch_acc"]
66+
pitch_rate_deriv = k1 * rpy[1] + k2 * rpy_rates_local[1] + k3 * attitude[1]
67+
k1, k2, k3 = SYS_ID_PARAMS["yaw_acc"]
68+
yaw_rate_deriv = k1 * rpy[2] + k2 * rpy_rates_local[2] + k3 * attitude[2]
69+
rpy_rates_deriv = jnp.array([roll_rate_deriv, pitch_rate_deriv, yaw_rate_deriv])
7770
# The identified dynamics model does not use forces or torques, because we assume no knowledge
7871
# of the drone's mass and inertia. However, to remain compatible with the physics pipeline, we
7972
# return surrogate forces and torques that result in the desired acceleration and rpy rates
8073
# derivative. When converting back to the state derivative, the mass and inertia will cancel
8174
# out, resulting in the correct acceleration and rpy rates derivative regardless of the model's
8275
# mass and inertia.
83-
surrogate_torques = rot.apply(J @ rpy_rates_local_deriv)
76+
surrogate_torques = rot.apply(J @ rpy_rates_deriv)
8477
surrogate_forces = acc * mass
8578
return surrogate_forces, surrogate_torques
8679

@@ -94,7 +87,8 @@ def collective_force2acceleration(force: Array, mass: Array) -> Array:
9487
@partial(vectorize, signature="(3),(4),(3,3)->(3)")
9588
def collective_torque2rpy_rates_deriv(torque: Array, quat: Array, J_INV: Array) -> Array:
9689
"""Convert torques to rpy_rates_deriv."""
97-
return R.from_quat(quat).apply(J_INV @ torque)
90+
rot = R.from_quat(quat)
91+
return rot.apply(J_INV @ rot.apply(torque, inverse=True))
9892

9993

10094
@partial(vectorize, signature="(4),(4),(3),(3,3)->(3),(3)")

crazyflow/sim/sim.py

+71-25
Original file line numberDiff line numberDiff line change
@@ -66,37 +66,22 @@ def __init__(
6666

6767
# Initialize MuJoCo world and data
6868
self._xml_path = xml_path or self.default_path
69-
self.spec, self.mj_model, self.mj_data, self.mjx_model, mjx_data = self.setup_mj()
69+
self.spec = self.init_mjx_spec()
70+
self.mj_model, self.mj_data, self.mjx_model, mjx_data = self.init_mjx_model(self.spec)
7071
self.viewer: MujocoRenderer | None = None
7172

72-
# Allocate internal states and controls
73-
drone_ids = [self.mj_model.body(f"drone:{i}").id for i in range(n_drones)]
74-
self.data = SimData(
75-
states=SimState.create(n_worlds, n_drones, self.device),
76-
states_deriv=SimStateDeriv.create(n_worlds, n_drones, self.device),
77-
controls=SimControls.create(
78-
n_worlds, n_drones, state_freq, attitude_freq, thrust_freq, self.device
79-
),
80-
params=SimParams.create(n_worlds, n_drones, MASS, J, J_INV, self.device),
81-
core=SimCore.create(freq, n_worlds, n_drones, drone_ids, rng_key, self.device),
82-
mjx_data=mjx_data,
83-
mjx_model=None,
73+
self.data, self.default_data = self.init_data(
74+
state_freq, attitude_freq, thrust_freq, rng_key, mjx_data
8475
)
85-
if self.n_drones > 1: # If multiple drones, arrange them in a grid
86-
grid = grid_2d(self.n_drones)
87-
states = self.data.states.replace(pos=self.data.states.pos.at[..., :2].set(grid))
88-
self.data: SimData = self.data.replace(states=states)
89-
90-
self.data = self.sync_sim2mjx(self.data, self.mjx_model)
91-
self.default_data = self.data.replace() # TODO: Only save the data of one world
9276

9377
# Default functions for the simulation pipeline
9478
self.disturbance_fn: Callable[[SimData], SimData] | None = None
9579

9680
# Build the simulation pipeline and overwrite the default _step implementation with it
97-
self.build()
81+
self.init_step_fn()
9882

99-
def setup_mj(self) -> tuple[Any, Any, Any, Model, Data]:
83+
def init_mjx_spec(self) -> mujoco.MjSpec:
84+
"""Build the MuJoCo model specification for the simulation."""
10085
assert self._xml_path.exists(), f"Model file {self._xml_path} does not exist"
10186
spec = mujoco.MjSpec.from_file(str(self._xml_path))
10287
spec.option.timestep = 1 / self.freq
@@ -110,7 +95,10 @@ def setup_mj(self) -> tuple[Any, Any, Any, Model, Data]:
11095
for i in range(self.n_drones):
11196
drone = frame.attach_body(drone_spec.find_body("drone"), "", f":{i}")
11297
drone.add_freejoint()
113-
# Compile and create data structures
98+
return spec
99+
100+
def init_mjx_model(self, spec: mujoco.MjSpec) -> tuple[Any, Any, Model, Data]:
101+
"""Build the MuJoCo model and data structures for the simulation."""
114102
mj_model = spec.compile()
115103
mj_data = mujoco.MjData(mj_model)
116104
mjx_model = mjx.put_model(mj_model, device=self.device)
@@ -120,9 +108,9 @@ def setup_mj(self) -> tuple[Any, Any, Any, Model, Data]:
120108
# https://github.com/jax-ml/jax/issues/4274#issuecomment-692406759
121109
# Tracking issue: https://github.com/google-deepmind/mujoco/issues/2306
122110
mjx_data = mjx_data.replace(time=jnp.float32(mjx_data.time))
123-
return spec, mj_model, mj_data, mjx_model, mjx_data
111+
return mj_model, mj_data, mjx_model, mjx_data
124112

125-
def build(self):
113+
def init_step_fn(self):
126114
"""Setup the chain of functions that are called in Sim.step().
127115
128116
We know all the functions that are called in succession since the simulation is configured
@@ -179,6 +167,62 @@ def step(data: SimData, n_steps: int = 1) -> SimData:
179167

180168
self._step = step
181169

170+
def init_data(
171+
self, state_freq: int, attitude_freq: int, thrust_freq: int, rng_key: Array, mjx_data: Data
172+
) -> tuple[SimData, SimData]:
173+
"""Initialize the simulation data."""
174+
drone_ids = [self.mj_model.body(f"drone:{i}").id for i in range(self.n_drones)]
175+
N, D = self.n_worlds, self.n_drones
176+
data = SimData(
177+
states=SimState.create(N, D, self.device),
178+
states_deriv=SimStateDeriv.create(N, D, self.device),
179+
controls=SimControls.create(N, D, state_freq, attitude_freq, thrust_freq, self.device),
180+
params=SimParams.create(N, D, MASS, J, J_INV, self.device),
181+
core=SimCore.create(self.freq, N, D, drone_ids, rng_key, self.device),
182+
mjx_data=mjx_data,
183+
mjx_model=None,
184+
)
185+
if D > 1: # If multiple drones, arrange them in a grid
186+
grid = grid_2d(D)
187+
states = data.states.replace(pos=data.states.pos.at[..., :2].set(grid))
188+
data = data.replace(states=states)
189+
data = self.sync_sim2mjx(data, self.mjx_model)
190+
191+
return data, data.replace() # TODO: Only save the data of one world
192+
193+
def build(self, mjx: bool = True, data: bool = True, step: bool = True):
194+
"""Build the simulation pipeline.
195+
196+
This method is used to (re)build the simulation pipeline after changing the MuJoCo
197+
model specification or any of the default functions that are used in the compiled step
198+
function.
199+
200+
Warning:
201+
Depending on what you build, you reset the simulation state. For example, rebuilding the
202+
simulation data will reset the drone states.
203+
204+
Args:
205+
mjx: Flag to (re)build the MuJoCo model and data structures.
206+
data: Flag to (re)build the simulation data.
207+
step: Flag to (re)build the simulation step function.
208+
"""
209+
# TODO: Write tests for all options
210+
if mjx:
211+
if self.viewer is not None:
212+
self.viewer.close()
213+
self.viewer = None
214+
self.mj_model, self.mj_data, self.mjx_model, mjx_data = self.init_mjx_model(self.spec)
215+
if data:
216+
self.data, self.default_data = self.init_data(
217+
self.data.controls.state_freq,
218+
self.data.controls.attitude_freq,
219+
self.data.controls.thrust_freq,
220+
self.data.core.rng_key,
221+
self.data.mjx_data if not mjx else mjx_data,
222+
)
223+
if step:
224+
self.init_step_fn()
225+
182226
def reset(self, mask: Array | None = None):
183227
"""Reset the simulation to the initial state.
184228
@@ -236,6 +280,7 @@ def seed(self, seed: int):
236280
def close(self):
237281
if self.viewer is not None:
238282
self.viewer.close()
283+
self.viewer = None
239284

240285
@property
241286
def time(self) -> Array:
@@ -299,6 +344,7 @@ def sync_sim2mjx(data: SimData, mjx_model: Model | None = None) -> SimData:
299344
qvel = rearrange(jnp.concat([vel, local_ang_vel], axis=-1), "w d qvel -> w (d qvel)")
300345
mjx_data = data.mjx_data
301346
mjx_model = data.mjx_model if mjx_model is None else mjx_model
347+
assert mjx_model is not None, "MuJoCo model is not initialized"
302348
mjx_data = mjx_data.replace(qpos=qpos, qvel=qvel)
303349
mjx_data = mjx_kinematics(mjx_model, mjx_data)
304350
mjx_data = mjx_collision(mjx_model, mjx_data)

crazyflow/sim/symbolic.py

+88-7
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from numpy.typing import NDArray
2121

2222
from crazyflow.constants import ARM_LEN, GRAVITY, SIGN_MIX_MATRIX
23-
from crazyflow.control.control import KF, KM
23+
from crazyflow.control.control import KF, KM, Control
2424
from crazyflow.sim import Sim
2525

2626

@@ -143,9 +143,87 @@ def setup_linearization(self):
143143
self.loss = cs.Function("loss", l_inputs, l_outputs, l_inputs_str, l_outputs_str)
144144

145145

146-
def symbolic(mass: float, J: NDArray, dt: float) -> SymbolicModel:
146+
def symbolic_attitude(dt: float) -> SymbolicModel:
147147
"""Create symbolic (CasADi) models for dynamics, observation, and cost of a quadcopter.
148148
149+
This model is based on the identified model derived from real-world data of the Crazyflie 2.1.
150+
151+
Returns:
152+
The CasADi symbolic model of the environment.
153+
"""
154+
# # Define states.
155+
z = MX.sym("z")
156+
z_dot = MX.sym("z_dot")
157+
g = GRAVITY
158+
# # Set up the dynamics model for a 3D quadrotor.
159+
nx, nu = 12, 4
160+
# Define states.
161+
x = cs.MX.sym("x")
162+
x_dot = cs.MX.sym("x_dot")
163+
y = cs.MX.sym("y")
164+
y_dot = cs.MX.sym("y_dot")
165+
phi = cs.MX.sym("phi") # roll angle [rad]
166+
phi_dot = cs.MX.sym("phi_dot")
167+
theta = cs.MX.sym("theta") # pitch angle [rad]
168+
theta_dot = cs.MX.sym("theta_dot")
169+
psi = cs.MX.sym("psi") # yaw angle [rad]
170+
psi_dot = cs.MX.sym("psi_dot")
171+
X = cs.vertcat(x, x_dot, y, y_dot, z, z_dot, phi, theta, psi, phi_dot, theta_dot, psi_dot)
172+
# Define input collective thrust and theta.
173+
T = cs.MX.sym("T_c") # normalized thrust [N]
174+
R = cs.MX.sym("R_c") # desired roll angle [rad]
175+
P = cs.MX.sym("P_c") # desired pitch angle [rad]
176+
Y = cs.MX.sym("Y_c") # desired yaw angle [rad]
177+
U = cs.vertcat(T, R, P, Y)
178+
# The thrust in PWM is converted from the normalized thrust.
179+
# With the formulat F_desired = b_F * T + a_F
180+
params_acc = [20.907574256269616, 3.653687545690674]
181+
params_roll_rate = [-130.3, -16.33, 119.3]
182+
params_pitch_rate = [-99.94, -13.3, 84.73]
183+
# The identified model sets params_yaw_rate to [0, 0, 0], because the training data did not
184+
# contain any data with yaw != 0. Therefore, it cannot infer the impact of setting the yaw
185+
# attitude to a non-zero value on the dynamics. However, using a zero vector will make the
186+
# system matrix ill-conditioned for control methods like LQR. Therefore, we introduce a small
187+
# spring-like term to the yaw dynamics that leads to a non-singular system matrix.
188+
# TODO: identify proper parameters for yaw_rate from real data.
189+
params_yaw_rate = [-0.01, 0, 0]
190+
191+
# Define dynamics equations.
192+
X_dot = cs.vertcat(
193+
x_dot,
194+
(params_acc[0] * T + params_acc[1])
195+
* (cs.cos(phi) * cs.sin(theta) * cs.cos(psi) + cs.sin(phi) * cs.sin(psi)),
196+
y_dot,
197+
(params_acc[0] * T + params_acc[1])
198+
* (cs.cos(phi) * cs.sin(theta) * cs.sin(psi) - cs.sin(phi) * cs.cos(psi)),
199+
z_dot,
200+
(params_acc[0] * T + params_acc[1]) * cs.cos(phi) * cs.cos(theta) - g,
201+
phi_dot,
202+
theta_dot,
203+
psi_dot,
204+
params_roll_rate[0] * phi + params_roll_rate[1] * phi_dot + params_roll_rate[2] * R,
205+
params_pitch_rate[0] * theta + params_pitch_rate[1] * theta_dot + params_pitch_rate[2] * P,
206+
params_yaw_rate[0] * psi + params_yaw_rate[1] * psi_dot + params_yaw_rate[2] * Y,
207+
)
208+
# Define observation.
209+
Y = cs.vertcat(x, x_dot, y, y_dot, z, z_dot, phi, theta, psi, phi_dot, theta_dot, psi_dot)
210+
211+
# Define cost (quadratic form).
212+
Q, R = MX.sym("Q", nx, nx), MX.sym("R", nu, nu)
213+
Xr, Ur = MX.sym("Xr", nx, 1), MX.sym("Ur", nu, 1)
214+
cost_func = 0.5 * (X - Xr).T @ Q @ (X - Xr) + 0.5 * (U - Ur).T @ R @ (U - Ur)
215+
# Define dynamics and cost dictionaries.
216+
dynamics = {"dyn_eqn": X_dot, "obs_eqn": Y, "vars": {"X": X, "U": U}}
217+
cost = {"cost_func": cost_func, "vars": {"X": X, "U": U, "Xr": Xr, "Ur": Ur, "Q": Q, "R": R}}
218+
return SymbolicModel(dynamics=dynamics, cost=cost, dt=dt)
219+
220+
221+
def symbolic_thrust(mass: float, J: NDArray, dt: float) -> SymbolicModel:
222+
"""Create symbolic (CasADi) models for dynamics, observation, and cost of a quadcopter.
223+
224+
This model is based on the analytical model of Luis, Carlos, and Jérôme Le Ny. "Design of a
225+
trajectory tracking controller for a nanoquadcopter." arXiv preprint arXiv:1608.05786 (2016).
226+
149227
Returns:
150228
The CasADi symbolic model of the environment.
151229
"""
@@ -173,9 +251,6 @@ def symbolic(mass: float, J: NDArray, dt: float) -> SymbolicModel:
173251
f1, f2, f3, f4 = MX.sym("f1"), MX.sym("f2"), MX.sym("f3"), MX.sym("f4")
174252
U = cs.vertcat(f1, f2, f3, f4)
175253

176-
# From Ch. 2 of Luis, Carlos, and Jérôme Le Ny. "Design of a trajectory tracking
177-
# controller for a nanoquadcopter." arXiv preprint arXiv:1608.05786 (2016).
178-
179254
# Defining the dynamics function.
180255
# We are using the velocity of the base wrt to the world frame expressed in the world frame.
181256
# Note that the reference expresses this in the body frame.
@@ -226,8 +301,14 @@ def symbolic_from_sim(sim: Sim) -> SymbolicModel:
226301
The model is expected to deviate from the true dynamics when the sim parameters are
227302
randomized.
228303
"""
229-
mass, J = sim.default_data.params.mass[0, 0], sim.default_data.params.J[0, 0]
230-
return symbolic(mass, J, 1 / sim.freq)
304+
match sim.control:
305+
case Control.attitude:
306+
return symbolic_attitude(1 / sim.freq)
307+
case Control.thrust:
308+
mass, J = sim.default_data.params.mass[0, 0], sim.default_data.params.J[0, 0]
309+
return symbolic_thrust(mass, J, 1 / sim.freq)
310+
case _:
311+
raise ValueError(f"Unsupported control type for symbolic model: {sim.control}")
231312

232313

233314
def csRotXYZ(phi: float, theta: float, psi: float) -> MX:

examples/disturbance.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def main(plot: bool = False):
3636

3737
# Second run
3838
sim.disturbance_fn = disturbance_fn
39-
sim.build()
39+
sim.build(mjx=False, data=False, step=True)
4040
pos_disturbed, rpy_disturbed = [], []
4141
sim.reset()
4242
for _ in range(3 * sim.control_freq):

0 commit comments

Comments
 (0)