Skip to content

Commit d983e19

Browse files
committed
[wip] Implemented the Mellinger controller and adjusted the structs
1 parent adaf38b commit d983e19

File tree

2 files changed

+136
-34
lines changed

2 files changed

+136
-34
lines changed

crazyflow/sim/sim.py

+45-8
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer
1111
from jax import Array, Device
1212
from jax.scipy.spatial.transform import Rotation as R
13+
from lsy_models.controllers_numeric import cntrl_mellinger_attitude, cntrl_mellinger_position
1314
from lsy_models.models_numeric import f_first_principles, f_fitted_DI_rpy
15+
from lsy_models.utils.cf2 import force2pwm, pwm2force
1416
from mujoco.mjx import Data, Model
1517

1618
from crazyflow.constants import J_INV, MASS, SIGN_MIX_MATRIX, J
@@ -164,7 +166,7 @@ def init_data(
164166
states=SimState.create(N, D, self.device),
165167
states_deriv=SimStateDeriv.create(N, D, self.device),
166168
controls=SimControls.create(N, D, state_freq, attitude_freq, thrust_freq, self.device),
167-
params=SimParams.create(N, D, MASS, J, J_INV, self.device),
169+
params=SimParams.create(N, D, mass=MASS, J=J, device=self.device),
168170
core=SimCore.create(self.freq, N, D, drone_ids, rng_key, self.device),
169171
mjx_data=mjx_data,
170172
mjx_model=None,
@@ -503,11 +505,34 @@ def step_state_controller(data: SimData) -> SimData:
503505
des_pos, des_vel = controls.state[..., :3], controls.state[..., 3:6]
504506
des_yaw = controls.state[..., [9]] # Keep (N, M, 1) shape for broadcasting
505507
dt = 1 / data.controls.state_freq
508+
506509
attitude, pos_err_i = state2attitude(
507510
states.pos, states.vel, states.quat, des_pos, des_vel, des_yaw, controls.pos_err_i, dt
508511
)
512+
# Bringing the command into the correct format
513+
command_RPYT = jnp.roll(attitude, -1, axis=-1) # bring into RPYT format
514+
command_RPYT = command_RPYT.at[..., -1].set(
515+
force2pwm(command_RPYT[..., -1], data.params)
516+
) # thrust (N) -> thrust (PWM)
517+
command_RPYT = command_RPYT.at[..., :-1].set(command_RPYT[..., :-1] * 180 / jnp.pi) # rad2deg
518+
519+
# command_RPYT, pos_err_i = cntrl_mellinger_position(
520+
# states.pos,
521+
# states.quat,
522+
# states.vel,
523+
# states.ang_vel,
524+
# controls.state,
525+
# data.params,
526+
# dt=1 / data.controls.state_freq,
527+
# i_error=controls.pos_err_i,
528+
# )
529+
509530
controls = leaf_replace(
510-
controls, mask, state_steps=data.core.steps, staged_attitude=attitude, pos_err_i=pos_err_i
531+
controls,
532+
mask,
533+
state_steps=data.core.steps,
534+
staged_attitude=command_RPYT,
535+
pos_err_i=pos_err_i,
511536
)
512537
return data.replace(controls=controls)
513538

@@ -520,12 +545,24 @@ def step_attitude_controller(data: SimData) -> SimData:
520545
# Commit the staged attitude controls
521546
staged_attitude = controls.staged_attitude
522547
controls = leaf_replace(controls, mask, attitude_steps=steps, attitude=staged_attitude)
523-
# Compute the new thrust values from the committed attitude controls
524-
quat, attitude = data.states.quat, controls.attitude
525-
dt = 1 / controls.attitude_freq
526-
thrust, rpy_err_i = attitude2thrust(attitude, quat, controls.last_rpy, controls.rpy_err_i, dt)
527-
rpy = R.from_quat(quat).as_euler("xyz")
528-
controls = leaf_replace(controls, mask, thrust=thrust, rpy_err_i=rpy_err_i, last_rpy=rpy)
548+
549+
# Calculating the control by stepping the controller
550+
forces, rpy_err_i = cntrl_mellinger_attitude(
551+
data.states.pos,
552+
data.states.quat,
553+
data.states.vel,
554+
data.states.ang_vel,
555+
data.controls.attitude, # command_RPYT
556+
data.params,
557+
dt=1 / controls.attitude_freq,
558+
i_error_m=controls.rpy_err_i,
559+
prev_angular_vel=controls.prev_ang_vel,
560+
)
561+
562+
controls = leaf_replace(
563+
controls, mask, thrust=forces, rpy_err_i=rpy_err_i, prev_ang_vel=data.states.ang_vel
564+
)
565+
529566
return data.replace(controls=controls)
530567

531568

crazyflow/sim/structs.py

+91-26
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,8 @@ class SimControls:
119119
"""Integral of the rpy error."""
120120
pos_err_i: Array # (N, M, 3)
121121
"""Integral of the position error."""
122-
last_rpy: Array # (N, M, 3)
123-
"""Last rpy for 'xyz' euler angles.
122+
prev_ang_vel: Array # (N, M, 3)
123+
"""Aangular velocity from the last controller step.
124124
125125
Required to compute the integral term in the attitude controller.
126126
"""
@@ -144,58 +144,123 @@ def create(
144144
staged_attitude=jnp.zeros((n_worlds, n_drones, 4), device=device),
145145
attitude_steps=-jnp.ones((n_worlds, 1), dtype=jnp.int32, device=device),
146146
attitude_freq=attitude_freq,
147-
thrust=jnp.zeros((n_worlds, n_drones, 4), device=device),
147+
thrust=jnp.ones((n_worlds, n_drones, 4), device=device)
148+
* 0.08, # TODO remove and rather make floor solid!
148149
thrust_steps=-jnp.ones((n_worlds, 1), dtype=jnp.int32, device=device),
149150
thrust_freq=thrust_freq,
150151
rpms=jnp.zeros((n_worlds, n_drones, 4), device=device),
151152
rpy_err_i=jnp.zeros((n_worlds, n_drones, 3), device=device),
152153
pos_err_i=jnp.zeros((n_worlds, n_drones, 3), device=device),
153-
last_rpy=jnp.zeros((n_worlds, n_drones, 3), device=device),
154+
prev_ang_vel=jnp.zeros((n_worlds, n_drones, 3), device=device),
154155
)
155156

156157

157158
@dataclass
158159
class SimParams:
159-
mass: Array # (N, M, 1)
160+
# Variable params (for domain randomization) => (N, M, shape)
161+
MASS: Array # (N, M, 1)
160162
"""Mass of the drone."""
161163
J: Array # (N, M, 3, 3)
162164
"""Inertia matrix of the drone."""
163165
J_INV: Array # (N, M, 3, 3)
164166
"""Inverse of the inertia matrix of the drone."""
167+
L: Array # (N, M, 1)
168+
"""Arm length of the drone, aka distance of the motors from the center of mass."""
165169

166-
# TODO: Remove duplicate definition of constants. Move into constants from lsy_models
167-
THRUST_TAU: float = field(pytree_node=False)
168-
SIGN_MATRIX: NDArray = field(pytree_node=False)
169-
L: float = field(pytree_node=False)
170-
KF: float = field(pytree_node=False)
171-
KM: float = field(pytree_node=False)
172-
GRAVITY_VEC: NDArray = field(pytree_node=False)
173-
MASS: float = field(pytree_node=False)
174-
J_inv: NDArray = field(pytree_node=False)
170+
# TODO maybe params, maybe constants?
171+
KF: float = field(pytree_node=False) # (N, M, 1)
172+
"""RPM squared to Force factor."""
173+
KM: float = field(pytree_node=False) # (N, M, 1)
174+
"""RPM squared to Torque factor."""
175+
THRUST_MIN: float = field(pytree_node=False) # (N, M, 1)
176+
"""Min thrust per motor."""
177+
THRUST_MAX: float = field(pytree_node=False) # (N, M, 1)
178+
"""Max thrust per motor."""
179+
THRUST_TAU: float = field(pytree_node=False) # (N, M, 1)
180+
# TODO maybe N, M, 4, for each of the motors individually
181+
"""Time constant for the thrust dynamics."""
182+
183+
# Constants
184+
GRAVITY_VEC: Array = field(pytree_node=False)
185+
# MIX_MATRIX: Array = field(pytree_node=False) # TODO not needed? => remove
186+
SIGN_MATRIX: Array = field(pytree_node=False)
187+
PWM_MIN: float = field(pytree_node=False)
188+
PWM_MAX: float = field(pytree_node=False)
189+
190+
# System Identification (SI) parameters
191+
SI_ROLL: Array = field(pytree_node=False)
192+
SI_PITCH: Array = field(pytree_node=False)
193+
SI_YAW: Array = field(pytree_node=False)
194+
SI_PARAMS: Array = field(pytree_node=False)
195+
SI_ACC: Array = field(pytree_node=False)
196+
197+
# System Identification parameters for the double integrator (DI) model
198+
DI_ROLL: Array = field(pytree_node=False)
199+
DI_PITCH: Array = field(pytree_node=False)
200+
DI_YAW: Array = field(pytree_node=False)
201+
DI_PARAMS: Array = field(pytree_node=False)
202+
DI_ACC: Array = field(pytree_node=False)
175203

176204
@staticmethod
177205
def create(
178-
n_worlds: int, n_drones: int, mass: float, J: Array, J_INV: Array, device: Device
206+
n_worlds: int,
207+
n_drones: int,
208+
mass: float,
209+
J: Array,
210+
device: Device,
211+
L: float | None = None,
212+
KF: float | None = None,
213+
KM: float | None = None,
214+
THRUST_MIN: float | None = None,
215+
THRUST_MAX: float | None = None,
216+
THRUST_TAU: float | None = None,
179217
) -> SimParams:
180218
"""Create a default set of parameters for the simulation."""
181-
mass = jnp.ones((n_worlds, n_drones, 1), device=device) * mass
182-
j, j_inv = jnp.array(J, device=device), jnp.array(J_INV, device=device)
219+
MASS = jnp.ones((n_worlds, n_drones, 1), device=device) * mass
220+
j = jnp.array(J, device=device)
221+
j_inv = jnp.linalg.inv(j)
183222
J = jnp.tile(j[None, None, :, :], (n_worlds, n_drones, 1, 1))
184223
J_INV = jnp.tile(j_inv[None, None, :, :], (n_worlds, n_drones, 1, 1))
224+
185225
constants = Constants.from_config("cf2x_L250")
186226

227+
if L is None:
228+
L = constants.L
229+
if KF is None:
230+
KF = constants.KF
231+
if KM is None:
232+
KM = constants.KM
233+
if THRUST_MIN is None:
234+
THRUST_MIN = constants.THRUST_MIN
235+
if THRUST_MAX is None:
236+
THRUST_MAX = constants.THRUST_MAX
237+
if THRUST_TAU is None:
238+
THRUST_TAU = constants.THRUST_TAU
239+
187240
return SimParams(
188-
mass=mass,
189-
J=constants.J,
241+
MASS=MASS,
242+
J=J,
190243
J_INV=J_INV,
191-
THRUST_TAU=constants.THRUST_TAU,
192-
SIGN_MATRIX=constants.SIGN_MATRIX,
193-
L=constants.L,
194-
KF=constants.KF,
195-
KM=constants.KM,
244+
L=L,
245+
KF=KF,
246+
KM=KM,
247+
THRUST_MIN=THRUST_MIN,
248+
THRUST_MAX=THRUST_MAX,
249+
THRUST_TAU=THRUST_TAU * 1, # TODO remove
196250
GRAVITY_VEC=constants.GRAVITY_VEC,
197-
MASS=constants.MASS,
198-
J_inv=constants.J_inv,
251+
SIGN_MATRIX=constants.SIGN_MATRIX,
252+
PWM_MIN=constants.PWM_MIN,
253+
PWM_MAX=constants.PWM_MAX,
254+
SI_ROLL=constants.SI_ROLL,
255+
SI_PITCH=constants.SI_PITCH,
256+
SI_YAW=constants.SI_YAW,
257+
SI_PARAMS=constants.SI_PARAMS,
258+
SI_ACC=constants.SI_ACC,
259+
DI_ROLL=constants.DI_ROLL,
260+
DI_PITCH=constants.DI_PITCH,
261+
DI_YAW=constants.DI_YAW,
262+
DI_PARAMS=constants.DI_PARAMS,
263+
DI_ACC=constants.DI_ACC,
199264
)
200265

201266

0 commit comments

Comments
 (0)