Skip to content

Commit 438c3eb

Browse files
committed
[wip] Removed links to local constants and controllers
1 parent 3e3a70d commit 438c3eb

File tree

9 files changed

+188
-180
lines changed

9 files changed

+188
-180
lines changed

crazyflow/constants.py

+17-13
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
1-
import numpy as np
2-
from numpy.typing import NDArray
3-
4-
# Physical constants
5-
GRAVITY: float = 9.81
6-
7-
# Drone constants
8-
ARM_LEN: float = 0.0325 * np.sqrt(2)
9-
MIX_MATRIX: NDArray = np.array([[-0.5, -0.5, -1], [-0.5, 0.5, 1], [0.5, 0.5, -1], [0.5, -0.5, 1]])
10-
SIGN_MIX_MATRIX: NDArray = np.sign(MIX_MATRIX)
11-
MASS: float = 0.03253
12-
J: NDArray = np.array([[2.3951e-5, 0, 0], [0, 2.3951e-5, 0], [0, 0, 3.2347e-5]])
13-
J_INV: NDArray = np.linalg.inv(J)
1+
# import numpy as np
2+
# from numpy.typing import NDArray
3+
4+
# # Physical constants
5+
# GRAVITY: float = 9.81
6+
7+
# # Drone constants
8+
# ARM_LEN: float = 0.0325 * np.sqrt(2)
9+
# MIX_MATRIX: NDArray = np.array([[-0.5, -0.5, -1], [-0.5, 0.5, 1], [0.5, 0.5, -1], [0.5, -0.5, 1]])
10+
# SIGN_MIX_MATRIX: NDArray = np.sign(MIX_MATRIX)
11+
# MASS: float = 0.03253
12+
# J: NDArray = np.array([[2.3951e-5, 0, 0], [0, 2.3951e-5, 0], [0, 0, 3.2347e-5]])
13+
# J_INV: NDArray = np.linalg.inv(J)
14+
15+
from lsy_models.utils.constants import Constants
16+
17+
constants = Constants.from_config("cf2x_L250") # TODO make dependent on actual drone config

crazyflow/control/control.py

+131-130
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,16 @@
1111
"""
1212

1313
from enum import Enum
14-
from functools import partial
1514

16-
import jax
17-
import jax.numpy as jnp
18-
import numpy as np
19-
from jax import Array
20-
from jax.scipy.spatial.transform import Rotation as R
15+
# from functools import partial
2116

22-
from crazyflow.constants import GRAVITY, MASS, MIX_MATRIX
17+
# import jax
18+
# import jax.numpy as jnp
19+
# import numpy as np
20+
# from jax import Array
21+
# from jax.scipy.spatial.transform import Rotation as R
22+
23+
# from crazyflow.constants import GRAVITY, MASS, MIX_MATRIX
2324

2425

2526
class Control(str, Enum):
@@ -50,126 +51,126 @@ class Control(str, Enum):
5051
default = attitude
5152

5253

53-
KF: float = 3.16e-10
54-
KM: float = 7.94e-12
55-
P_F: Array = np.array([0.4, 0.4, 1.25])
56-
I_F: Array = np.array([0.05, 0.05, 0.05])
57-
D_F: Array = np.array([0.2, 0.2, 0.5])
58-
I_F_RANGE: Array = np.array([2.0, 2.0, 0.4])
59-
P_T: Array = np.array([70000.0, 70000.0, 60000.0])
60-
I_T: Array = np.array([0.0, 0.0, 500.0])
61-
D_T: Array = np.array([20000.0, 20000.0, 12000.0])
62-
PWM2RPM_SCALE: float = 0.2685
63-
PWM2RPM_CONST: float = 4070.3
64-
MIN_PWM: float = 20000
65-
MAX_PWM: float = 65535
66-
MIN_RPM: float = PWM2RPM_SCALE * MIN_PWM + PWM2RPM_CONST
67-
MAX_RPM: float = PWM2RPM_SCALE * MAX_PWM + PWM2RPM_CONST
68-
MIN_THRUST: float = KF * MIN_RPM**2
69-
MAX_THRUST: float = KF * MAX_RPM**2
70-
# Thrust curve parameters for brushed motors
71-
THRUST_CURVE_A: float = -1.1264
72-
THRUST_CURVE_B: float = 2.2541
73-
THRUST_CURVE_C: float = 0.0209
74-
75-
76-
@partial(jnp.vectorize, signature="(3),(3),(4),(3),(3),(1),(3)->(4),(3)", excluded=[7])
77-
def state2attitude(
78-
pos: Array,
79-
vel: Array,
80-
quat: Array,
81-
des_pos: Array,
82-
des_vel: Array,
83-
des_yaw: Array,
84-
i_error: Array,
85-
dt: float,
86-
) -> tuple[Array, Array]:
87-
"""Compute the next desired collective thrust and roll/pitch/yaw of the drone."""
88-
pos_error, vel_error = des_pos - pos, des_vel - vel
89-
# Update integral error
90-
i_error = jnp.clip(i_error + pos_error * dt, -I_F_RANGE, I_F_RANGE)
91-
# Compute target thrust
92-
thrust = P_F * pos_error + I_F * i_error + D_F * vel_error
93-
thrust = thrust.at[2].add(MASS * GRAVITY)
94-
# Update z_axis to the current orientation of the drone
95-
z_axis = R.from_quat(quat).as_matrix()[:, 2]
96-
# Project the thrust onto the z-axis
97-
thrust_desired = jnp.clip(thrust @ z_axis, 0.3 * MASS * GRAVITY, 1.8 * MASS * GRAVITY)
98-
# Update the desired z-axis
99-
z_axis = thrust / jnp.linalg.norm(thrust)
100-
yaw_axis = jnp.concatenate([jnp.cos(des_yaw), jnp.sin(des_yaw), jnp.array([0.0])])
101-
y_axis = jnp.cross(z_axis, yaw_axis)
102-
y_axis = y_axis / jnp.linalg.norm(y_axis)
103-
x_axis = jnp.cross(y_axis, z_axis)
104-
euler_desired = R.from_matrix(jnp.vstack([x_axis, y_axis, z_axis]).T).as_euler("xyz")
105-
return jnp.concatenate([jnp.atleast_1d(thrust_desired), euler_desired]), i_error
106-
107-
108-
@partial(jnp.vectorize, signature="(4),(4),(3),(3)->(4),(3)", excluded=[4])
109-
def attitude2thrust(
110-
controls: Array, quat: Array, last_rpy: Array, rpy_err_i: Array, dt: float
111-
) -> tuple[Array, Array]:
112-
"""Convert the desired collective thrust and attitude into motor RPMs."""
113-
rot = R.from_quat(quat)
114-
target_rot = R.from_euler("xyz", controls[1:])
115-
drot = (target_rot.inv() * rot).as_matrix()
116-
# Extract the anti-symmetric part of the relative rotation matrix.
117-
rot_e = jnp.array([drot[2, 1] - drot[1, 2], drot[0, 2] - drot[2, 0], drot[1, 0] - drot[0, 1]])
118-
# TODO: Assumes zero rpy_rates targets for now, use the actual target instead.
119-
rpy_rates_e = -(rot.as_euler("xyz") - last_rpy) / dt
120-
rpy_err_i = rpy_err_i - rot_e * dt
121-
rpy_err_i = jnp.clip(rpy_err_i, -1500.0, 1500.0)
122-
rpy_err_i = rpy_err_i.at[:2].set(jnp.clip(rpy_err_i[:2], -1.0, 1.0))
123-
# PID target torques.
124-
target_torques = -P_T * rot_e + D_T * rpy_rates_e + I_T * rpy_err_i
125-
target_torques = jnp.clip(target_torques, -3200, 3200)
126-
thrust_per_motor = jnp.atleast_1d(controls[0]) / 4
127-
pwm = jnp.clip(thrust2pwm(thrust_per_motor) + MIX_MATRIX @ target_torques, MIN_PWM, MAX_PWM)
128-
return pwm2thrust(pwm), rpy_err_i
129-
130-
131-
@partial(jnp.vectorize, signature="(4)->(4)")
132-
def thrust2pwm(thrust: Array) -> Array:
133-
"""Convert the desired thrust into motor PWM.
134-
135-
Args:
136-
thrust: The desired thrust per motor.
137-
138-
Returns:
139-
The motors' PWMs to apply to the quadrotor.
140-
"""
141-
thrust = jnp.clip(thrust, MIN_THRUST, MAX_THRUST) # Protect against NaN values
142-
return jnp.clip((jnp.sqrt(thrust / KF) - PWM2RPM_CONST) / PWM2RPM_SCALE, MIN_PWM, MAX_PWM)
143-
144-
145-
@partial(jnp.vectorize, signature="(4)->(4)")
146-
def pwm2rpm(pwm: Array) -> Array:
147-
"""Convert the motors' PWMs into RPMs."""
148-
return PWM2RPM_CONST + PWM2RPM_SCALE * pwm
149-
150-
151-
@partial(jnp.vectorize, signature="(4)->(4)")
152-
def pwm2thrust(pwm: Array) -> Array:
153-
"""Convert the motors' RPMs into thrust."""
154-
return jnp.clip(((pwm * PWM2RPM_SCALE) + PWM2RPM_CONST) ** 2 * KF, MIN_THRUST, MAX_THRUST)
155-
156-
157-
@jax.jit
158-
def thrust_curve(thrust: Array) -> Array:
159-
"""Compute the quadratic thrust curve of the crazyflie.
160-
161-
Warning:
162-
This function is not used by the simulation. It is only used as interface to the firmware.
163-
164-
Todo:
165-
Find out where this function is used in the firmware and emulate its use in our onboard
166-
controller reimplementation.
167-
168-
Args:
169-
thrust: The desired motor thrust.
170-
171-
Returns:
172-
The motors' PWMs to apply to the quadrotor.
173-
"""
174-
tau = THRUST_CURVE_A * thrust**2 + THRUST_CURVE_B * thrust + THRUST_CURVE_C
175-
return jnp.clip(tau * MAX_PWM, MIN_PWM, MAX_PWM)
54+
# KF: float = 3.16e-10
55+
# KM: float = 7.94e-12
56+
# P_F: Array = np.array([0.4, 0.4, 1.25])
57+
# I_F: Array = np.array([0.05, 0.05, 0.05])
58+
# D_F: Array = np.array([0.2, 0.2, 0.5])
59+
# I_F_RANGE: Array = np.array([2.0, 2.0, 0.4])
60+
# P_T: Array = np.array([70000.0, 70000.0, 60000.0])
61+
# I_T: Array = np.array([0.0, 0.0, 500.0])
62+
# D_T: Array = np.array([20000.0, 20000.0, 12000.0])
63+
# PWM2RPM_SCALE: float = 0.2685
64+
# PWM2RPM_CONST: float = 4070.3
65+
# MIN_PWM: float = 20000
66+
# MAX_PWM: float = 65535
67+
# MIN_RPM: float = PWM2RPM_SCALE * MIN_PWM + PWM2RPM_CONST
68+
# MAX_RPM: float = PWM2RPM_SCALE * MAX_PWM + PWM2RPM_CONST
69+
# THRUST_MIN: float = KF * MIN_RPM**2
70+
# THRUST_MAX: float = KF * MAX_RPM**2
71+
# # Thrust curve parameters for brushed motors
72+
# THRUST_CURVE_A: float = -1.1264
73+
# THRUST_CURVE_B: float = 2.2541
74+
# THRUST_CURVE_C: float = 0.0209
75+
76+
77+
# @partial(jnp.vectorize, signature="(3),(3),(4),(3),(3),(1),(3)->(4),(3)", excluded=[7])
78+
# def state2attitude(
79+
# pos: Array,
80+
# vel: Array,
81+
# quat: Array,
82+
# des_pos: Array,
83+
# des_vel: Array,
84+
# des_yaw: Array,
85+
# i_error: Array,
86+
# dt: float,
87+
# ) -> tuple[Array, Array]:
88+
# """Compute the next desired collective thrust and roll/pitch/yaw of the drone."""
89+
# pos_error, vel_error = des_pos - pos, des_vel - vel
90+
# # Update integral error
91+
# i_error = jnp.clip(i_error + pos_error * dt, -I_F_RANGE, I_F_RANGE)
92+
# # Compute target thrust
93+
# thrust = P_F * pos_error + I_F * i_error + D_F * vel_error
94+
# thrust = thrust.at[2].add(MASS * GRAVITY)
95+
# # Update z_axis to the current orientation of the drone
96+
# z_axis = R.from_quat(quat).as_matrix()[:, 2]
97+
# # Project the thrust onto the z-axis
98+
# thrust_desired = jnp.clip(thrust @ z_axis, 0.3 * MASS * GRAVITY, 1.8 * MASS * GRAVITY)
99+
# # Update the desired z-axis
100+
# z_axis = thrust / jnp.linalg.norm(thrust)
101+
# yaw_axis = jnp.concatenate([jnp.cos(des_yaw), jnp.sin(des_yaw), jnp.array([0.0])])
102+
# y_axis = jnp.cross(z_axis, yaw_axis)
103+
# y_axis = y_axis / jnp.linalg.norm(y_axis)
104+
# x_axis = jnp.cross(y_axis, z_axis)
105+
# euler_desired = R.from_matrix(jnp.vstack([x_axis, y_axis, z_axis]).T).as_euler("xyz")
106+
# return jnp.concatenate([jnp.atleast_1d(thrust_desired), euler_desired]), i_error
107+
108+
109+
# @partial(jnp.vectorize, signature="(4),(4),(3),(3)->(4),(3)", excluded=[4])
110+
# def attitude2thrust(
111+
# controls: Array, quat: Array, last_rpy: Array, rpy_err_i: Array, dt: float
112+
# ) -> tuple[Array, Array]:
113+
# """Convert the desired collective thrust and attitude into motor RPMs."""
114+
# rot = R.from_quat(quat)
115+
# target_rot = R.from_euler("xyz", controls[1:])
116+
# drot = (target_rot.inv() * rot).as_matrix()
117+
# # Extract the anti-symmetric part of the relative rotation matrix.
118+
# rot_e = jnp.array([drot[2, 1] - drot[1, 2], drot[0, 2] - drot[2, 0], drot[1, 0] - drot[0, 1]])
119+
# # TODO: Assumes zero rpy_rates targets for now, use the actual target instead.
120+
# rpy_rates_e = -(rot.as_euler("xyz") - last_rpy) / dt
121+
# rpy_err_i = rpy_err_i - rot_e * dt
122+
# rpy_err_i = jnp.clip(rpy_err_i, -1500.0, 1500.0)
123+
# rpy_err_i = rpy_err_i.at[:2].set(jnp.clip(rpy_err_i[:2], -1.0, 1.0))
124+
# # PID target torques.
125+
# target_torques = -P_T * rot_e + D_T * rpy_rates_e + I_T * rpy_err_i
126+
# target_torques = jnp.clip(target_torques, -3200, 3200)
127+
# thrust_per_motor = jnp.atleast_1d(controls[0]) / 4
128+
# pwm = jnp.clip(thrust2pwm(thrust_per_motor) + MIX_MATRIX @ target_torques, MIN_PWM, MAX_PWM)
129+
# return pwm2thrust(pwm), rpy_err_i
130+
131+
132+
# @partial(jnp.vectorize, signature="(4)->(4)")
133+
# def thrust2pwm(thrust: Array) -> Array:
134+
# """Convert the desired thrust into motor PWM.
135+
136+
# Args:
137+
# thrust: The desired thrust per motor.
138+
139+
# Returns:
140+
# The motors' PWMs to apply to the quadrotor.
141+
# """
142+
# thrust = jnp.clip(thrust, THRUST_MIN, THRUST_MAX) # Protect against NaN values
143+
# return jnp.clip((jnp.sqrt(thrust / KF) - PWM2RPM_CONST) / PWM2RPM_SCALE, MIN_PWM, MAX_PWM)
144+
145+
146+
# @partial(jnp.vectorize, signature="(4)->(4)")
147+
# def pwm2rpm(pwm: Array) -> Array:
148+
# """Convert the motors' PWMs into RPMs."""
149+
# return PWM2RPM_CONST + PWM2RPM_SCALE * pwm
150+
151+
152+
# @partial(jnp.vectorize, signature="(4)->(4)")
153+
# def pwm2thrust(pwm: Array) -> Array:
154+
# """Convert the motors' RPMs into thrust."""
155+
# return jnp.clip(((pwm * PWM2RPM_SCALE) + PWM2RPM_CONST) ** 2 * KF, THRUST_MIN, THRUST_MAX)
156+
157+
158+
# @jax.jit
159+
# def thrust_curve(thrust: Array) -> Array:
160+
# """Compute the quadratic thrust curve of the crazyflie.
161+
162+
# Warning:
163+
# This function is not used by the simulation. It is only used as interface to the firmware.
164+
165+
# Todo:
166+
# Find out where this function is used in the firmware and emulate its use in our onboard
167+
# controller reimplementation.
168+
169+
# Args:
170+
# thrust: The desired motor thrust.
171+
172+
# Returns:
173+
# The motors' PWMs to apply to the quadrotor.
174+
# """
175+
# tau = THRUST_CURVE_A * thrust**2 + THRUST_CURVE_B * thrust + THRUST_CURVE_C
176+
# return jnp.clip(tau * MAX_PWM, MIN_PWM, MAX_PWM)

crazyflow/gymnasium_envs/crazyflow.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
from gymnasium.vector.utils import batch_space
1313
from jax import Array
1414

15-
from crazyflow.control.control import MAX_THRUST, MIN_THRUST, Control
15+
from crazyflow.constants import constants
16+
from crazyflow.control.control import Control
1617
from crazyflow.sim import Sim
1718
from crazyflow.sim.structs import SimState
1819

@@ -29,11 +30,11 @@ def action_space(control_type: Control) -> spaces.Box:
2930
match control_type:
3031
case Control.attitude:
3132
return spaces.Box(
32-
np.array([4 * MIN_THRUST, -np.pi, -np.pi, -np.pi], dtype=np.float32),
33-
np.array([4 * MAX_THRUST, np.pi, np.pi, np.pi], dtype=np.float32),
33+
np.array([4 * constants.THRUST_MIN, -np.pi, -np.pi, -np.pi], dtype=np.float32),
34+
np.array([4 * constants.THRUST_MAX, np.pi, np.pi, np.pi], dtype=np.float32),
3435
)
3536
case Control.thrust:
36-
return spaces.Box(MIN_THRUST, MAX_THRUST, shape=(4,))
37+
return spaces.Box(constants.THRUST_MIN, constants.THRUST_MAX, shape=(4,))
3738
case _:
3839
raise ValueError(f"Invalid control type {control_type}")
3940

crazyflow/sim/physics.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88
from jax.numpy import vectorize
99
from jax.scipy.spatial.transform import Rotation as R
1010

11-
from crazyflow.constants import ARM_LEN, GRAVITY, SIGN_MIX_MATRIX
12-
from crazyflow.control.control import KF, KM
11+
from crazyflow.constants import constants
1312

1413
SYS_ID_PARAMS = {
1514
"acc": np.array([20.907574256269616, 3.653687545690674]),
@@ -80,7 +79,7 @@ def surrogate_identified_collective_wrench(
8079
@partial(vectorize, signature="(3),(1)->(3)")
8180
def collective_force2acceleration(force: Array, mass: Array) -> Array:
8281
"""Convert forces to acceleration."""
83-
return force / mass - jnp.array([0, 0, GRAVITY])
82+
return force / mass - jnp.array([0, 0, constants.GRAVITY])
8483

8584

8685
@partial(vectorize, signature="(3),(4),(3,3)->(3)")
@@ -104,22 +103,22 @@ def rpms2collective_wrench(
104103
@partial(vectorize, signature="(4)->(4)")
105104
def rpms2motor_forces(rpms: Array) -> Array:
106105
"""Convert RPMs to motor forces (body frame, along the z-axis)."""
107-
return rpms**2 * KF
106+
return rpms**2 * constants.KF
108107

109108

110109
@partial(vectorize, signature="(4)->(4)")
111110
def rpms2motor_torques(rpms: Array) -> Array:
112111
"""Convert RPMs to motor torques (body frame, around the z-axis)."""
113-
return rpms**2 * KM
112+
return rpms**2 * constants.KM
114113

115114

116115
@partial(vectorize, signature="(4),(3),(4),(3,3)->(3)")
117116
def rpms2body_torque(rpms: Array, ang_vel: Array, motor_forces: Array, J: Array) -> Array:
118117
"""Convert RPMs to torques in the body frame."""
119118
motor_torques = rpms2motor_torques(rpms)
120-
z_torque = SIGN_MIX_MATRIX[..., 2] @ motor_torques
121-
x_torque = SIGN_MIX_MATRIX[..., 0] @ motor_forces * (ARM_LEN / jnp.sqrt(2))
122-
y_torque = SIGN_MIX_MATRIX[..., 1] @ motor_forces * (ARM_LEN / jnp.sqrt(2))
119+
z_torque = constants.SIGN_MIX_MATRIX[..., 2] @ motor_torques
120+
x_torque = constants.SIGN_MIX_MATRIX[..., 0] @ motor_forces * (constants.L)
121+
y_torque = constants.SIGN_MIX_MATRIX[..., 1] @ motor_forces * (constants.L)
123122
return jnp.array([x_torque, y_torque, z_torque]) - jnp.cross(ang_vel, J @ ang_vel)
124123

125124

0 commit comments

Comments
 (0)