Skip to content

Commit 3ea0c3e

Browse files
committed
[wip,broken] Transition to lsy_models
1 parent 1f59380 commit 3ea0c3e

File tree

3 files changed

+65
-60
lines changed

3 files changed

+65
-60
lines changed

crazyflow/control/control.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def state2attitude(
106106

107107

108108
@partial(jnp.vectorize, signature="(4),(4),(3),(3)->(4),(3)", excluded=[4])
109-
def attitude2rpm(
109+
def attitude2thrust(
110110
controls: Array, quat: Array, last_rpy: Array, rpy_err_i: Array, dt: float
111111
) -> tuple[Array, Array]:
112112
"""Convert the desired collective thrust and attitude into motor RPMs."""
@@ -125,7 +125,7 @@ def attitude2rpm(
125125
target_torques = jnp.clip(target_torques, -3200, 3200)
126126
thrust_per_motor = jnp.atleast_1d(controls[0]) / 4
127127
pwm = jnp.clip(thrust2pwm(thrust_per_motor) + MIX_MATRIX @ target_torques, MIN_PWM, MAX_PWM)
128-
return pwm2rpm(pwm), rpy_err_i
128+
return pwm2thrust(pwm), rpy_err_i
129129

130130

131131
@partial(jnp.vectorize, signature="(4)->(4)")
@@ -148,6 +148,12 @@ def pwm2rpm(pwm: Array) -> Array:
148148
return PWM2RPM_CONST + PWM2RPM_SCALE * pwm
149149

150150

151+
@partial(jnp.vectorize, signature="(4)->(4)")
152+
def pwm2thrust(pwm: Array) -> Array:
153+
"""Convert the motors' RPMs into thrust."""
154+
return jnp.clip(((pwm * PWM2RPM_SCALE) + PWM2RPM_CONST) ** 2 * KF, MIN_THRUST, MAX_THRUST)
155+
156+
151157
@jax.jit
152158
def thrust_curve(thrust: Array) -> Array:
153159
"""Compute the quadratic thrust curve of the crazyflie.

crazyflow/sim/integration.py

+31-9
Original file line numberDiff line numberDiff line change
@@ -60,40 +60,61 @@ def rk4_average(k1: SimData, k2: SimData, k3: SimData, k4: SimData) -> SimData:
6060
def integrate(data: SimData, deriv: SimData, dt: float) -> SimData:
6161
"""Integrate the dynamics forward in time."""
6262
states, states_deriv = data.states, deriv.states_deriv
63-
pos, quat, vel, ang_vel = states.pos, states.quat, states.vel, states.ang_vel
64-
dpos, drot = states_deriv.dpos, states_deriv.drot
65-
dvel, dang_vel = states_deriv.dvel, states_deriv.dang_vel
66-
next_pos, next_quat, next_vel, next_ang_vel = _integrate(
67-
pos, quat, vel, ang_vel, dpos, drot, dvel, dang_vel, dt
63+
next_pos, next_quat, next_vel, next_ang_vel, next_motor_forces = _integrate(
64+
states.pos,
65+
states.quat,
66+
states.vel,
67+
states.ang_vel,
68+
states.motor_forces,
69+
states_deriv.dpos,
70+
states_deriv.drot,
71+
states_deriv.dvel,
72+
states_deriv.dang_vel,
73+
states_deriv.dmotor_forces,
74+
dt,
6875
)
6976
return data.replace(
70-
states=states.replace(pos=next_pos, quat=next_quat, vel=next_vel, ang_vel=next_ang_vel)
77+
states=states.replace(
78+
pos=next_pos,
79+
quat=next_quat,
80+
vel=next_vel,
81+
ang_vel=next_ang_vel,
82+
motor_forces=next_motor_forces,
83+
)
7184
)
7285

7386

74-
@partial(vectorize, signature="(3),(4),(3),(3),(3),(3),(3),(3)->(3),(4),(3),(3)", excluded=[8])
87+
@partial(
88+
vectorize,
89+
signature="(3),(4),(3),(3),(4),(3),(3),(3),(3),(4)->(3),(4),(3),(3),(4)",
90+
excluded=[10],
91+
)
7592
def _integrate(
7693
pos: Array,
7794
quat: Array,
7895
vel: Array,
7996
ang_vel: Array,
97+
motor_forces: Array,
8098
dpos: Array,
8199
drot: Array,
82100
dvel: Array,
83101
dang_vel: Array,
102+
dmotor_forces: Array,
84103
dt: float,
85-
) -> tuple[Array, Array, Array, Array]:
104+
) -> tuple[Array, Array, Array, Array, Array]:
86105
"""Integrate the dynamics forward in time.
87106
88107
Args:
89108
pos: The position of the drone.
90109
quat: The orientation of the drone as a quaternion.
91110
vel: The velocity of the drone.
92111
ang_vel: The angular velocity of the drone.
112+
motor_forces: The forces of the motors.
93113
dpos: The derivative of the position of the drone.
94114
drot: The derivative of the quaternion of the drone (3D angular velocity).
95115
dvel: The derivative of the velocity of the drone.
96116
dang_vel: The derivative of the angular velocity of the drone.
117+
dmotor_forces: The derivative of the motor forces of the drone.
97118
dt: The time step to integrate over.
98119
99120
Returns:
@@ -103,4 +124,5 @@ def _integrate(
103124
next_quat = (R.from_quat(quat) * R.from_rotvec(drot * dt)).as_quat()
104125
next_vel = vel + dvel * dt
105126
next_ang_vel = ang_vel + dang_vel * dt
106-
return next_pos, next_quat, next_vel, next_ang_vel
127+
next_motor_forces = motor_forces + dmotor_forces * dt
128+
return next_pos, next_quat, next_vel, next_ang_vel, next_motor_forces

crazyflow/sim/sim.py

+26-49
Original file line numberDiff line numberDiff line change
@@ -10,22 +10,14 @@
1010
from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer
1111
from jax import Array, Device
1212
from jax.scipy.spatial.transform import Rotation as R
13-
from lsy_models.models_numeric import f_first_principles
13+
from lsy_models.models_numeric import f_first_principles, f_fitted_DI_rpy
1414
from mujoco.mjx import Data, Model
1515

1616
from crazyflow.constants import J_INV, MASS, SIGN_MIX_MATRIX, J
17-
from crazyflow.control.control import Control, attitude2rpm, pwm2rpm, state2attitude, thrust2pwm
17+
from crazyflow.control.control import Control, attitude2thrust, pwm2rpm, state2attitude, thrust2pwm
1818
from crazyflow.exception import ConfigError, NotInitializedError
1919
from crazyflow.sim.integration import Integrator, euler, rk4
20-
from crazyflow.sim.physics import (
21-
Physics,
22-
collective_force2acceleration,
23-
collective_torque2ang_vel_deriv,
24-
rpms2collective_wrench,
25-
rpms2motor_forces,
26-
rpms2motor_torques,
27-
surrogate_identified_collective_wrench,
28-
)
20+
from crazyflow.sim.physics import Physics, rpms2motor_forces, rpms2motor_torques
2921
from crazyflow.sim.structs import SimControls, SimCore, SimData, SimParams, SimState, SimStateDeriv
3022
from crazyflow.utils import grid_2d, leaf_replace, patch_viewer, pytree_replace, to_device
3123

@@ -121,7 +113,6 @@ def build_step_fn(self):
121113
# functions. They act as factories that produce building blocks for the construction of our
122114
# simulation pipeline.
123115
ctrl_fn = generate_control_fn(self.control)
124-
wrench_fn = generate_wrench_fn(self.physics)
125116
disturbance_fn = self.disturbance_fn
126117
physics_fn = generate_physics_fn(self.physics, self.integrator)
127118
sync_fn = generate_sync_fn(self.physics)
@@ -130,7 +121,6 @@ def build_step_fn(self):
130121
def single_step(data: SimData, _: None) -> tuple[SimData, None]:
131122
data = ctrl_fn(data)
132123
data = disturbance_fn(data)
133-
data = wrench_fn(data)
134124
data = physics_fn(data)
135125
data = data.replace(core=data.core.replace(steps=data.core.steps + 1))
136126
# MuJoCo needs to sync after every physics step, so that the next step control, wrench
@@ -431,19 +421,6 @@ def generate_control_fn(control: Control) -> Callable[[SimData], SimData]:
431421
raise NotImplementedError(f"Control mode {control} not implemented")
432422

433423

434-
def generate_wrench_fn(physics: Physics) -> Callable[[SimData], SimData]:
435-
"""Generate the wrench function for the given physics mode."""
436-
match physics:
437-
case Physics.analytical:
438-
return analytical_wrench
439-
case Physics.sys_id:
440-
return identified_wrench
441-
case Physics.mujoco:
442-
return mujoco_wrench
443-
case _:
444-
raise NotImplementedError(f"Physics mode {physics} not implemented")
445-
446-
447424
def generate_derivative_fn(physics: Physics) -> Callable[[SimData], SimData]:
448425
"""Generate the derivative function for the given physics mode."""
449426
match physics:
@@ -543,12 +520,12 @@ def step_attitude_controller(data: SimData) -> SimData:
543520
# Commit the staged attitude controls
544521
staged_attitude = controls.staged_attitude
545522
controls = leaf_replace(controls, mask, attitude_steps=steps, attitude=staged_attitude)
546-
# Compute the new rpm values from the committed attitude controls
523+
# Compute the new thrust values from the committed attitude controls
547524
quat, attitude = data.states.quat, controls.attitude
548525
dt = 1 / controls.attitude_freq
549-
rpms, rpy_err_i = attitude2rpm(attitude, quat, controls.last_rpy, controls.rpy_err_i, dt)
526+
thrust, rpy_err_i = attitude2thrust(attitude, quat, controls.last_rpy, controls.rpy_err_i, dt)
550527
rpy = R.from_quat(quat).as_euler("xyz")
551-
controls = leaf_replace(controls, mask, rpms=rpms, rpy_err_i=rpy_err_i, last_rpy=rpy)
528+
controls = leaf_replace(controls, mask, thrust=thrust, rpy_err_i=rpy_err_i, last_rpy=rpy)
552529
return data.replace(controls=controls)
553530

554531

@@ -557,20 +534,14 @@ def step_thrust_controller(data: SimData) -> SimData:
557534
controls = data.controls
558535
steps = data.core.steps
559536
mask = controllable(steps, data.core.freq, controls.thrust_steps, controls.thrust_freq)
560-
rpms = pwm2rpm(thrust2pwm(controls.thrust))
561-
controls = leaf_replace(controls, mask, thrust_steps=steps, rpms=rpms)
537+
raise NotImplementedError("Thrust controller currently not implemented. Missing staging.")
538+
# TODO: Introduce thrust staging
539+
controls = leaf_replace(controls, mask, thrust_steps=steps, thrust=controls.thrust)
562540
return data.replace(controls=controls)
563541

564542

565-
def analytical_wrench(data: SimData) -> SimData:
566-
"""Compute the wrench from the analytical dynamics model."""
567-
states, controls, params = data.states, data.controls, data.params
568-
force, torque = rpms2collective_wrench(controls.rpms, states.quat, states.ang_vel, params.J)
569-
return data.replace(states=data.states.replace(force=force, torque=torque))
570-
571-
572543
def analytical_derivative(data: SimData) -> SimData:
573-
"""Compute the derivative of the states."""
544+
"""Compute the state derivative from first principles."""
574545
dpos, _, dvel, dang_vel, df_motor = f_first_principles(
575546
data.states.pos,
576547
data.states.quat,
@@ -588,17 +559,23 @@ def analytical_derivative(data: SimData) -> SimData:
588559
return data.replace(states_deriv=states_deriv)
589560

590561

591-
def identified_wrench(data: SimData) -> SimData:
592-
"""Compute the wrench from the identified dynamics model."""
593-
states, controls = data.states, data.controls
594-
mass, J = data.params.mass, data.params.J
595-
force, torque = surrogate_identified_collective_wrench(
596-
controls.attitude, states.quat, states.ang_vel, mass, J, 1 / data.core.freq
562+
def identified_derivative(data: SimData) -> SimData:
563+
"""Compute the state derivative from the identified dynamics model."""
564+
dpos, _, dvel, dang_vel, df_motor = f_fitted_DI_rpy(
565+
data.states.pos,
566+
data.states.quat,
567+
data.states.vel,
568+
data.states.ang_vel,
569+
data.controls.thrust,
570+
data.params,
571+
None, # Fitted model does not have motor dynamics, we assume control can be matched
572+
data.states.force,
573+
data.states.torque,
597574
)
598-
return data.replace(states=data.states.replace(force=force, torque=torque))
599-
600-
601-
identified_derivative = analytical_derivative # We can use the same derivative function for both
575+
states_deriv = data.states_deriv.replace(
576+
dpos=dpos, drot=dang_vel, dvel=dvel, dang_vel=dang_vel, dmotor_forces=df_motor
577+
)
578+
return data.replace(states_deriv=states_deriv)
602579

603580

604581
def mujoco_wrench(data: SimData) -> SimData:

0 commit comments

Comments
 (0)