Skip to content

Commit 67fc2ae

Browse files
committed
Clean up formatting. Add test for symbolic and sim model mismatch
1 parent 472eda4 commit 67fc2ae

File tree

4 files changed

+102
-52
lines changed

4 files changed

+102
-52
lines changed

crazyflow/sim/symbolic.py

+31-31
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@
1919
from casadi import MX
2020
from numpy.typing import NDArray
2121

22-
from crazyflow.constants import ARM_LEN, GRAVITY, SIGN_MIX_MATRIX
23-
from crazyflow.control.control import KF, KM
22+
from crazyflow.constants import GRAVITY
2423
from crazyflow.sim import Sim
2524

2625

@@ -156,22 +155,22 @@ def symbolic(mass: float, J: NDArray, dt: float) -> SymbolicModel:
156155
# # Set up the dynamics model for a 3D quadrotor.
157156
nx, nu = 12, 4
158157
# Define states.
159-
x = cs.MX.sym('x')
160-
x_dot = cs.MX.sym('x_dot')
161-
y = cs.MX.sym('y')
162-
y_dot = cs.MX.sym('y_dot')
163-
phi = cs.MX.sym('phi') # roll angle [rad]
164-
phi_dot = cs.MX.sym('phi_dot')
165-
theta = cs.MX.sym('theta') # pitch angle [rad]
166-
theta_dot = cs.MX.sym('theta_dot')
167-
psi = cs.MX.sym('psi') # yaw angle [rad]
168-
psi_dot = cs.MX.sym('psi_dot')
158+
x = cs.MX.sym("x")
159+
x_dot = cs.MX.sym("x_dot")
160+
y = cs.MX.sym("y")
161+
y_dot = cs.MX.sym("y_dot")
162+
phi = cs.MX.sym("phi") # roll angle [rad]
163+
phi_dot = cs.MX.sym("phi_dot")
164+
theta = cs.MX.sym("theta") # pitch angle [rad]
165+
theta_dot = cs.MX.sym("theta_dot")
166+
psi = cs.MX.sym("psi") # yaw angle [rad]
167+
psi_dot = cs.MX.sym("psi_dot")
169168
X = cs.vertcat(x, x_dot, y, y_dot, z, z_dot, phi, theta, psi, phi_dot, theta_dot, psi_dot)
170169
# Define input collective thrust and theta.
171-
T = cs.MX.sym('T_c') # normalized thrust [N]
172-
R = cs.MX.sym('R_c') # desired roll angle [rad]
173-
P = cs.MX.sym('P_c') # desired pitch angle [rad]
174-
Y = cs.MX.sym('Y_c') # desired yaw angle [rad]
170+
T = cs.MX.sym("T_c") # normalized thrust [N]
171+
R = cs.MX.sym("R_c") # desired roll angle [rad]
172+
P = cs.MX.sym("P_c") # desired pitch angle [rad]
173+
Y = cs.MX.sym("Y_c") # desired yaw angle [rad]
175174
U = cs.vertcat(T, R, P, Y)
176175
# The thrust in PWM is converted from the normalized thrust.
177176
# With the formulat F_desired = b_F * T + a_F
@@ -185,24 +184,25 @@ def symbolic(mass: float, J: NDArray, dt: float) -> SymbolicModel:
185184

186185
# Define dynamics equations.
187186
# TODO: create a parameter for the new quad model
188-
X_dot = cs.vertcat(x_dot,
189-
(params_acc[0] * T + params_acc[1]) * (
190-
cs.cos(phi) * cs.sin(theta) * cs.cos(psi) + cs.sin(phi) * cs.sin(psi)),
191-
y_dot,
192-
(params_acc[0] * T + params_acc[1]) * (
193-
cs.cos(phi) * cs.sin(theta) * cs.sin(psi) - cs.sin(phi) * cs.cos(psi)),
194-
z_dot,
195-
(params_acc[0] * T + params_acc[1]) * cs.cos(phi) * cs.cos(theta) - g,
196-
phi_dot,
197-
theta_dot,
198-
psi_dot,
199-
params_roll_rate[0] * phi + params_roll_rate[1] * phi_dot + params_roll_rate[2] * R,
200-
params_pitch_rate[0] * theta + params_pitch_rate[1] * theta_dot + params_pitch_rate[2] * P,
201-
params_yaw_rate[0] * psi + params_yaw_rate[1] * psi_dot + params_yaw_rate[2] * Y)
187+
X_dot = cs.vertcat(
188+
x_dot,
189+
(params_acc[0] * T + params_acc[1])
190+
* (cs.cos(phi) * cs.sin(theta) * cs.cos(psi) + cs.sin(phi) * cs.sin(psi)),
191+
y_dot,
192+
(params_acc[0] * T + params_acc[1])
193+
* (cs.cos(phi) * cs.sin(theta) * cs.sin(psi) - cs.sin(phi) * cs.cos(psi)),
194+
z_dot,
195+
(params_acc[0] * T + params_acc[1]) * cs.cos(phi) * cs.cos(theta) - g,
196+
phi_dot,
197+
theta_dot,
198+
psi_dot,
199+
params_roll_rate[0] * phi + params_roll_rate[1] * phi_dot + params_roll_rate[2] * R,
200+
params_pitch_rate[0] * theta + params_pitch_rate[1] * theta_dot + params_pitch_rate[2] * P,
201+
params_yaw_rate[0] * psi + params_yaw_rate[1] * psi_dot + params_yaw_rate[2] * Y,
202+
)
202203
# Define observation.
203204
Y = cs.vertcat(x, x_dot, y, y_dot, z, z_dot, phi, theta, psi, phi_dot, theta_dot, psi_dot)
204205

205-
206206
# Define cost (quadratic form).
207207
Q, R = MX.sym("Q", nx, nx), MX.sym("R", nu, nu)
208208
Xr, Ur = MX.sym("Xr", nx, 1), MX.sym("Ur", nu, 1)

tests/integration/test_symbolic.py

+65
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import numpy as np
2+
import pytest
3+
from numpy.typing import NDArray
4+
from scipy.spatial.transform import Rotation as R
5+
6+
from crazyflow.control.control import MAX_THRUST, MIN_THRUST
7+
from crazyflow.sim import Sim
8+
from crazyflow.sim.structs import SimState
9+
from crazyflow.sim.symbolic import symbolic_from_sim
10+
11+
12+
def sim_state2symbolic_state(state: SimState) -> NDArray[np.float32]:
13+
"""Convert the simulation state to the symbolic state vector."""
14+
pos = state.pos.squeeze() # shape: (3,)
15+
vel = state.vel.squeeze() # shape: (3,)
16+
euler = R.from_quat(state.quat.squeeze()).as_euler("xyz") # shape: (3,), Euler angles
17+
rpy_rates = state.rpy_rates.squeeze() # shape: (3,)
18+
return np.array([pos[0], vel[0], pos[1], vel[1], pos[2], vel[2], *euler, *rpy_rates])
19+
20+
21+
@pytest.mark.integration
22+
def test_attitude_symbolic():
23+
sim = Sim(physics="sys_id")
24+
sym = symbolic_from_sim(sim)
25+
26+
x0 = np.zeros(12)
27+
28+
# Simulate with both models for 0.5 seconds
29+
t_end = 0.5
30+
dt = 1 / sim.freq
31+
steps = int(t_end / dt)
32+
33+
# Track states over time
34+
x_sym_log = []
35+
x_sim_log = []
36+
37+
# Initialize logs with initial state
38+
x_sym = x0.copy()
39+
x_sim = x0.copy()
40+
x_sym_log.append(x_sym)
41+
x_sim_log.append(x_sim)
42+
43+
u_low = np.array([4 * MIN_THRUST, -np.pi, -np.pi, -np.pi])
44+
u_high = np.array([4 * MAX_THRUST, np.pi, np.pi, np.pi])
45+
46+
# Run simulation
47+
for _ in range(steps):
48+
u_rand = (np.random.rand(4) * (u_high - u_low) + u_low).astype(np.float32)
49+
# Simulate with symbolic model
50+
res = sym.fd_func(x0=x_sym, p=u_rand)
51+
x_sym = res["xf"].full().flatten()
52+
x_sym_log.append(x_sym)
53+
# Simulate with attitude controller
54+
sim.attitude_control(u_rand.reshape(1, 1, 4))
55+
sim.step(sim.freq // sim.control_freq)
56+
x_sim_log.append(sim_state2symbolic_state(sim.data.states))
57+
58+
# Convert logs to arrays. Do not record the rpy rates (deviate easily).
59+
x_sym_log = np.array(x_sym_log)[..., :-3]
60+
x_sim_log = np.array(x_sim_log)[..., :-3]
61+
62+
# Check if states match throughout simulation
63+
err_msg = "Symbolic and simulation prediction do not match approximately"
64+
assert np.allclose(x_sym_log, x_sim_log, rtol=1e-2, atol=1e-3), err_msg
65+
sim.close()

tutorials/LQR_ILQR.ipynb

+5-10
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,9 @@
3333
"from scipy.spatial.transform import Rotation as R\n",
3434
"\n",
3535
"from crazyflow.constants import GRAVITY, MASS, J\n",
36-
"from crazyflow.sim.physics import Physics\n",
3736
"from crazyflow.control import Control\n",
38-
"from crazyflow.control.control import MAX_THRUST, MIN_THRUST\n",
39-
"from crazyflow.sim.symbolic import symbolic\n"
37+
"from crazyflow.sim.physics import Physics\n",
38+
"from crazyflow.sim.symbolic import symbolic"
4039
]
4140
},
4241
{
@@ -175,9 +174,7 @@
175174
"A, B = df[0].toarray(), df[1].toarray()\n",
176175
"\n",
177176
"print(\"A shape:\", A.shape) # Should be (12, 12)\n",
178-
"print(\"B shape:\", B.shape) # Should be (12, 4)\n",
179-
"# print(\"A :\\n\", A)\n",
180-
"# print(\"B :\\n\", B)"
177+
"print(\"B shape:\", B.shape) # Should be (12, 4)"
181178
]
182179
},
183180
{
@@ -284,7 +281,7 @@
284281
"# gain_lqr = np.dot(np.linalg.inv(R_lqr), np.dot(B.T, P))\n",
285282
"\n",
286283
"print(\"gain:\\n\", gain_lqr)\n",
287-
"print(\"shape of gain:\", gain_lqr.shape)\n"
284+
"print(\"shape of gain:\", gain_lqr.shape)"
288285
]
289286
},
290287
{
@@ -322,12 +319,10 @@
322319
"\n",
323320
" control_input = np.clip(control_input, envs.action_space.low, envs.action_space.high)\n",
324321
" action = control_input.reshape(1,4).astype(np.float32)\n",
325-
" # print(action)\n",
326322
" thrust_log.append(action.flatten())\n",
327323
" obs, reward, terminated, truncated, info = envs.step(action)\n",
328324
"\n",
329325
" state = obs_to_state(obs)\n",
330-
" # print('state:',state)\n",
331326
" x_log.append(state[0])\n",
332327
" y_log.append(state[2])\n",
333328
" z_log.append(state[4])\n",
@@ -377,7 +372,7 @@
377372
"plt.title(\"position vs Time\")\n",
378373
"plt.legend()\n",
379374
"plt.grid()\n",
380-
"plt.show()\n"
375+
"plt.show()"
381376
]
382377
},
383378
{

tutorials/compare_sim_and_symbolic.ipynb

+1-11
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
"from crazyflow.constants import MASS, J\n",
1717
"from crazyflow.control import Control\n",
1818
"from crazyflow.sim.physics import Physics\n",
19-
"from crazyflow.sim.symbolic import symbolic\n"
19+
"from crazyflow.sim.symbolic import symbolic"
2020
]
2121
},
2222
{
@@ -128,7 +128,6 @@
128128
"source": [
129129
"obs, info = envs.reset()\n",
130130
"state = obs_to_state(obs)\n",
131-
"print(state)\n",
132131
"x_sim_log.append(state[0])\n",
133132
"y_sim_log.append(state[2])\n",
134133
"z_sim_log.append(state[4])\n",
@@ -173,26 +172,19 @@
173172
"source": [
174173
"MIN_THRUST = 0.028161688\n",
175174
"MAX_THRUST = 0.14834145\n",
176-
"# shape = (1, 4)\n",
177175
"xf = state\n",
178176
"\n",
179177
"for i in range(200):\n",
180-
" # random_input = np.random.uniform(low=low, high=high, size=shape).astype(np.float32)\n",
181-
"\n",
182178
" random_input = np.array([\n",
183179
" np.random.uniform(4 * MIN_THRUST, 4 * MAX_THRUST), # thrust\n",
184180
" np.random.uniform(-np.pi, np.pi), # \n",
185181
" np.random.uniform(-np.pi, np.pi), # \n",
186182
" np.random.uniform(-np.pi, np.pi) # \n",
187183
" ], dtype=np.float32) \n",
188-
"\n",
189184
" random_input = random_input.reshape((1, 4))\n",
190185
"\n",
191-
" print(\"input:\", random_input)\n",
192-
"\n",
193186
" res = symbolic_model.fd_func(x0=xf, p=random_input)\n",
194187
" xf = res[\"xf\"].full().flatten()\n",
195-
" print(\"xf from symbolic:\", xf)\n",
196188
" x_sym_log.append(xf[0])\n",
197189
" y_sym_log.append(xf[2])\n",
198190
" z_sym_log.append(xf[4])\n",
@@ -209,7 +201,6 @@
209201
" obs, reward, terminated, truncated, info = envs.step(random_input)\n",
210202
" state = obs_to_state(obs)\n",
211203
"\n",
212-
" print(\"state from sim:\", state)\n",
213204
" x_sim_log.append(state[0])\n",
214205
" y_sim_log.append(state[2])\n",
215206
" z_sim_log.append(state[4])\n",
@@ -222,7 +213,6 @@
222213
" roll_dot_sim_log.append(state[9])\n",
223214
" pitch_dot_sim_log.append(state[10])\n",
224215
" yaw_dot_sim_log.append(state[11])\n",
225-
" envs.render()\n",
226216
"envs.sim.close()\n",
227217
"envs.close()"
228218
]

0 commit comments

Comments
 (0)