Skip to content

Commit 81cc1be

Browse files
committed
[wip] Work towards pipelined sim
1 parent bd89276 commit 81cc1be

File tree

6 files changed

+184
-86
lines changed

6 files changed

+184
-86
lines changed

crazyflow/sim/core.py

+44-52
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@
1414
from crazyflow.control.controller import J_INV, Control, Controller, J
1515
from crazyflow.exception import ConfigError, NotInitializedError
1616
from crazyflow.sim.fused import (
17+
attitude2rpm,
1718
fused_analytical_dynamics,
1819
fused_identified_dynamics,
19-
fused_masked_attitude2rpm,
20-
fused_masked_state2attitude,
2120
fused_rpms2collective_wrench,
21+
state2attitude,
2222
)
2323
from crazyflow.sim.integration import Integrator
2424
from crazyflow.sim.physics import Physics
@@ -125,12 +125,12 @@ def setup_pipeline(self) -> Callable[[int, SimData], SimData]:
125125
# drones (axis 1) in parallel.
126126
ctrl_fn = jax.vmap(jax.vmap(self._control_fn()))
127127
physics_fn = jax.vmap(jax.vmap(self._physics_fn()))
128-
integrator_fn = jax.vmap(jax.vmap(self._integrator_fn()))
128+
# integrator_fn = jax.vmap(jax.vmap(self._integrator_fn()))
129129

130130
def _step(sim_data: SimData) -> SimData:
131131
sim_data = ctrl_fn(sim_data)
132132
sim_data = physics_fn(sim_data)
133-
sim_data = integrator_fn(sim_data)
133+
# sim_data = integrator_fn(sim_data)
134134
return sim_data
135135

136136
# ``scan`` can be lowered to a single WhileOp, reducing compilation times while still fusing
@@ -233,7 +233,7 @@ def contacts(self, body: str | None = None) -> Array:
233233
def _control_fn(self) -> Callable[[SimData], SimData]:
234234
match self.control:
235235
case Control.state:
236-
return step_state_controller
236+
return lambda data: step_attitude_controller(step_state_controller(data))
237237
case Control.attitude:
238238
return step_attitude_controller
239239
case _:
@@ -242,9 +242,9 @@ def _control_fn(self) -> Callable[[SimData], SimData]:
242242
def _physics_fn(self) -> Callable[[SimData], SimData]:
243243
match self.physics:
244244
case Physics.analytical:
245-
return self._step_analytical
245+
return analytical_dynamics
246246
case Physics.sys_id:
247-
return self._step_sys_id
247+
return identified_dynamics
248248
case _:
249249
raise NotImplementedError(f"Physics mode {self.physics} not implemented")
250250

@@ -260,8 +260,7 @@ def _step_sys_id(self):
260260
# Optional optimization: check if mask.any() before updating the controls. This breaks jax's
261261
# gradient tracing, so we omit it for now.
262262
if self.control == Control.state:
263-
self.controls = self._masked_state_controls_update(mask, self.controls)
264-
self.controls = fused_masked_state2attitude(mask, self.states, self.controls, self.dt)
263+
self.controls = state2attitude(mask, self.states, self.controls, self.dt)
265264
self.controls = self._masked_attitude_controls_update(mask, self.controls)
266265
self.last_ctrl_steps = self._masked_controls_step_update(
267266
mask, self.steps, self.last_ctrl_steps
@@ -288,10 +287,9 @@ def _step_analytical(self):
288287
def _step_emulate_firmware(self) -> SimControls:
289288
mask = self.controllable
290289
if self.control == Control.state:
291-
self.controls = self._masked_state_controls_update(mask, self.controls)
292-
self.controls = fused_masked_state2attitude(mask, self.states, self.controls, self.dt)
290+
self.controls = state2attitude(mask, self.states, self.controls, self.dt)
293291
self.controls = self._masked_attitude_controls_update(mask, self.controls)
294-
return fused_masked_attitude2rpm(mask, self.states, self.controls, self.dt)
292+
return attitude2rpm(mask, self.states, self.controls, self.dt)
295293

296294
@staticmethod
297295
def _sync_mjx(states: SimState, mjx_data: Data, mjx_model: Model) -> Data:
@@ -323,64 +321,52 @@ def _sync_mjx_full(states: SimState, mjx_data: Data, mjx_model: Model) -> Data:
323321
@staticmethod
324322
@jax.jit
325323
def _masked_states_reset(mask: Array, states: SimState, defaults: SimState) -> SimState:
326-
mask_3d = mask[:, None, None]
327-
states = states.replace(pos=jnp.where(mask_3d, defaults.pos, states.pos))
328-
states = states.replace(quat=jnp.where(mask_3d, defaults.quat, states.quat))
329-
states = states.replace(vel=jnp.where(mask_3d, defaults.vel, states.vel))
330-
states = states.replace(ang_vel=jnp.where(mask_3d, defaults.ang_vel, states.ang_vel))
331-
states = states.replace(rpy_rates=jnp.where(mask_3d, defaults.rpy_rates, states.rpy_rates))
332-
return states
324+
mask = mask.reshape(-1, 1, 1)
325+
return jax.tree.map(lambda x, y: jnp.where(mask, y, x), states, defaults)
333326

334327
@staticmethod
335328
@jax.jit
336329
def _masked_controls_reset(
337330
mask: Array, controls: SimControls, defaults: SimControls
338331
) -> SimControls:
339-
mask = mask[:, None, None]
340-
controls = controls.replace(
341-
state=jnp.where(mask, defaults.state, controls.state),
342-
attitude=jnp.where(mask, defaults.attitude, controls.attitude),
343-
thrust=jnp.where(mask, defaults.thrust, controls.thrust),
344-
rpms=jnp.where(mask, defaults.rpms, controls.rpms),
345-
rpy_err_i=jnp.where(mask, defaults.rpy_err_i, controls.rpy_err_i),
346-
pos_err_i=jnp.where(mask, defaults.pos_err_i, controls.pos_err_i),
347-
last_rpy=jnp.where(mask, defaults.last_rpy, controls.last_rpy),
348-
staged_attitude=jnp.where(mask, defaults.staged_attitude, controls.staged_attitude),
349-
staged_state=jnp.where(mask, defaults.staged_state, controls.staged_state),
350-
)
351-
return controls
332+
mask = mask.reshape(-1, 1, 1)
333+
return jax.tree.map(lambda x, y: jnp.where(mask, y, x), controls, defaults)
352334

353335
@staticmethod
354336
@jax.jit
355337
def _masked_params_reset(mask: Array, params: SimParams, defaults: SimParams) -> SimParams:
356-
params = params.replace(mass=jnp.where(mask[:, None, None], defaults.mass, params.mass))
357-
mask_4d = mask[:, None, None, None]
358-
params = params.replace(J=jnp.where(mask_4d, defaults.J, params.J))
359-
params = params.replace(J_INV=jnp.where(mask_4d, defaults.J_INV, params.J_INV))
338+
mask = mask.reshape(-1, 1, 1)
339+
params = params.replace(mass=jnp.where(mask, defaults.mass, params.mass))
340+
# J and J_INV are matrices -> we need (W, D, N, N) = 4 dims
341+
mask = mask.reshape(-1, 1, 1, 1)
342+
params = params.replace(J=jnp.where(mask, defaults.J, params.J))
343+
params = params.replace(J_INV=jnp.where(mask, defaults.J_INV, params.J_INV))
360344
return params
361345

362346
@staticmethod
363347
@partial(jax.jit, static_argnames="device")
364348
def _attitude_control(cmd: Array, controls: SimControls, device: str) -> SimControls:
349+
"""Stage the desired attitude for all drones in all worlds.
350+
351+
We need to stage the attitude commands because the sys_id physics mode operates directly on
352+
the attitude command. If we were to directly update the controls, this would effectively
353+
bypass the control frequency and run the attitude controller at the physics update rate. By
354+
staging the commands, we ensure that the physics module sees the old commands until the
355+
controller updates at its correct frequency.
356+
"""
365357
return controls.replace(staged_attitude=jnp.array(cmd, device=device))
366358

367359
@staticmethod
368360
@partial(jax.jit, static_argnames="device")
369361
def _state_control(cmd: Array, controls: SimControls, device: str) -> SimControls:
370-
return controls.replace(staged_state=jnp.array(cmd, device=device))
362+
return controls.replace(state=jnp.array(cmd, device=device))
371363

372364
@staticmethod
373365
@jax.jit
374366
def _masked_attitude_controls_update(mask: Array, controls: SimControls) -> SimControls:
375367
cmd, staged_cmd = controls.attitude, controls.staged_attitude
376368
return controls.replace(attitude=jnp.where(mask[:, None, None], staged_cmd, cmd))
377369

378-
@staticmethod
379-
@jax.jit
380-
def _masked_state_controls_update(mask: Array, controls: SimControls) -> SimControls:
381-
cmd, staged_cmd = controls.state, controls.staged_state
382-
return controls.replace(state=jnp.where(mask[:, None, None], staged_cmd, cmd))
383-
384370
@staticmethod
385371
@jax.jit
386372
def _masked_controls_step_update(mask: Array, steps: Array, last_ctrl_steps: Array) -> Array:
@@ -403,24 +389,30 @@ def contacts(geom_start: int, geom_count: int, data: Data) -> Array:
403389

404390

405391
def step_state_controller(data: SimData) -> SimData:
392+
"""Compute the updated controls for the state controller."""
406393
controls = data.controls
407-
mask = controllable(data.steps, controls.steps, data.freq, controls.state_freq)
408-
controls = commit_state_controls(mask, controls)
394+
mask = controllable(data.steps, controls.state_steps, data.freq, controls.state_freq)
395+
controls = controls.replace(state_steps=jnp.where(mask, data.steps, controls.state_steps))
409396
controls = state2attitude(mask, data.states, controls, 1 / data.freq)
410397
return data.replace(controls=controls)
411398

412399

413-
def controllable(step: Array, ctrl_step: Array, ctrl_freq: int, freq: int) -> Array:
414-
return (step - ctrl_step) >= (freq / ctrl_freq)
400+
def step_attitude_controller(data: SimData) -> SimData:
401+
"""Compute the updated controls for the attitude controller."""
402+
controls = data.controls
403+
mask = controllable(data.steps, controls.attitude_steps, data.freq, controls.attitude_freq)
404+
controls = commit_attitude_controls(mask, controls)
405+
controls = attitude2rpm(mask, data.states, controls, 1 / data.freq)
406+
return data.replace(controls=controls)
415407

416408

417-
def commit_state_controls(mask: Array, controls: SimControls) -> SimControls:
418-
cmd, staged_cmd = controls.state, controls.staged_state
419-
return controls.replace(state=jnp.where(mask[:, None, None], staged_cmd, cmd))
409+
def controllable(step: Array, ctrl_step: Array, ctrl_freq: int, freq: int) -> Array:
410+
return (step - ctrl_step) >= (freq / ctrl_freq)
420411

421412

422-
def step_attitude_controller(data: SimData) -> SimData:
423-
pass
413+
def commit_attitude_controls(mask: Array, controls: SimControls) -> SimControls:
414+
cmd, staged_cmd = controls.attitude, controls.staged_attitude
415+
return controls.replace(attitude=jnp.where(mask.reshape(-1, 1, 1), staged_cmd, cmd))
424416

425417

426418
mjx_kinematics = jax.vmap(mjx.kinematics, in_axes=(None, 0))

crazyflow/sim/fused.py

+6-9
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
from jax import Array
1212
from jax.scipy.spatial.transform import Rotation as R
1313

14-
from crazyflow.control.controller import attitude2rpm, state2attitude
14+
from crazyflow.control.controller import attitude2rpm as attitude2rpm_ctrl
15+
from crazyflow.control.controller import state2attitude as state2attitude_ctrl
1516
from crazyflow.sim.physics import analytical_dynamics, identified_dynamics, rpms2collective_wrench
1617
from crazyflow.sim.structs import SimControls, SimParams, SimState
1718

@@ -66,9 +67,7 @@ def fused_analytical_dynamics(
6667
@jax.jit
6768
@partial(jax.vmap, in_axes=(0, 0, 0, None))
6869
@partial(jax.vmap, in_axes=(None, 0, 0, None))
69-
def fused_masked_state2attitude(
70-
mask: Array, state: SimState, cmd: SimControls, dt: float
71-
) -> SimControls:
70+
def state2attitude(mask: Array, state: SimState, cmd: SimControls, dt: float) -> SimControls:
7271
"""Compute the next desired collective thrust and roll/pitch/yaw of the drone.
7372
7473
Note:
@@ -86,7 +85,7 @@ def fused_masked_state2attitude(
8685
dt: The simulation time step.
8786
"""
8887
des_pos, des_vel, des_yaw = cmd.state[:3], cmd.state[3:6], cmd.state[9].reshape((1,))
89-
attitude, pos_err_i = state2attitude(
88+
attitude, pos_err_i = state2attitude_ctrl(
9089
state.pos, state.vel, state.quat, des_pos, des_vel, des_yaw, cmd.pos_err_i, dt
9190
)
9291
# Non-branching selection depending on the mask. XLA should be able to optimize a short path
@@ -99,9 +98,7 @@ def fused_masked_state2attitude(
9998
@jax.jit
10099
@partial(jax.vmap, in_axes=(0, 0, 0, None))
101100
@partial(jax.vmap, in_axes=(None, 0, 0, None))
102-
def fused_masked_attitude2rpm(
103-
mask: Array, state: SimState, cmd: SimControls, dt: float
104-
) -> SimControls:
101+
def attitude2rpm(mask: Array, state: SimState, cmd: SimControls, dt: float) -> SimControls:
105102
"""Compute the next desired RPMs of the drone.
106103
107104
Note:
@@ -114,7 +111,7 @@ def fused_masked_attitude2rpm(
114111
cmd: The current simulation controls.
115112
dt: The simulation time step.
116113
"""
117-
rpms, rpy_err_i = attitude2rpm(cmd.attitude, state.quat, cmd.last_rpy, cmd.rpy_err_i, dt)
114+
rpms, rpy_err_i = attitude2rpm_ctrl(cmd.attitude, state.quat, cmd.last_rpy, cmd.rpy_err_i, dt)
118115
# Non-branching selection depending on the mask. See fused_masked_state2attitude for more info.
119116
rpms = jnp.where(mask, rpms, cmd.rpms)
120117
rpy_err_i = jnp.where(mask, rpy_err_i, cmd.rpy_err_i)

crazyflow/sim/integration.py

+50
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,57 @@
11
from enum import Enum
2+
from typing import Callable
3+
4+
from jax import Array
5+
from scipy.spatial.transform import Rotation as R
6+
7+
from crazyflow.sim.structs import SimControls
28

39

410
class Integrator(str, Enum):
511
euler = "euler"
612
# TODO: Implement rk4
713
default = euler # TODO: Replace with rk4
14+
15+
16+
def euler(
17+
deriv_fn: Callable[[Array, Array, Array, Array, SimControls, float], tuple[Array, Array]],
18+
pos: Array,
19+
quat: Array,
20+
vel: Array,
21+
rpy_rates: Array,
22+
control: SimControls,
23+
dt: float,
24+
) -> tuple[Array, Array, Array, Array]:
25+
acc, rpy_rates_deriv = deriv_fn(pos, quat, vel, rpy_rates, control, dt)
26+
return _apply(pos, quat, vel, rpy_rates, acc, rpy_rates_deriv, dt)
27+
28+
29+
def rk4(
30+
deriv_fn: Callable[[Array, Array, Array, Array, SimControls, float], tuple[Array, Array]],
31+
pos: Array,
32+
quat: Array,
33+
vel: Array,
34+
rpy_rates: Array,
35+
control: SimControls,
36+
dt: float,
37+
) -> tuple[Array, Array, Array, Array]:
38+
raise NotImplementedError("RK4 not implemented")
39+
40+
41+
def _apply(
42+
pos: Array,
43+
quat: Array,
44+
vel: Array,
45+
rpy_rates: Array,
46+
acc: Array,
47+
rpy_rates_deriv: Array,
48+
dt: float,
49+
) -> tuple[Array, Array, Array, Array]:
50+
rot = R.from_quat(quat)
51+
rpy_rates_local = rot.apply(rpy_rates, inverse=True)
52+
rpy_rates_deriv_local = rot.apply(rpy_rates_deriv, inverse=True)
53+
next_pos = pos + vel * dt
54+
next_vel = vel + acc * dt
55+
next_rot = R.from_euler("xyz", rot.as_euler("xyz") + rpy_rates_local * dt)
56+
next_rpy_rates_local = rpy_rates_local + rpy_rates_deriv_local * dt
57+
return next_pos, next_rot.as_quat(), next_vel, next_rot.apply(next_rpy_rates_local)

crazyflow/sim/physics.py

+37-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from crazyflow.constants import ARM_LEN, GRAVITY, SIGN_MIX_MATRIX
88
from crazyflow.control.controller import KF, KM
9+
from crazyflow.sim.structs import SimControls, SimParams, SimState
910

1011
SYS_ID_PARAMS = {
1112
"acc_k1": 20.91,
@@ -69,6 +70,30 @@ def identified_dynamics(
6970
return next_pos, next_quat, next_vel, next_rpy_rates
7071

7172

73+
def identified_dynamics_dx(
74+
state: SimState, controls: SimControls, dt: float
75+
) -> tuple[Array, Array]:
76+
"""Derivative of the identified dynamics state."""
77+
collective_thrust, attitude = controls.attitude[0], controls.attitude[1:]
78+
rot = R.from_quat(state.quat)
79+
thrust = rot.apply(jnp.array([0, 0, collective_thrust]))
80+
drift = rot.apply(jnp.array([0, 0, 1]))
81+
a1, a2 = SYS_ID_PARAMS["acc_k1"], SYS_ID_PARAMS["acc_k2"]
82+
acc = thrust * a1 + drift * a2 - jnp.array([0, 0, GRAVITY])
83+
# rpy_rates_deriv have no real meaning in this context, since the identified dynamics set the
84+
# rpy_rates to the commanded values directly. However, since we use a unified integration
85+
# interface for all physics models, we cannot access states directly. Instead, we calculate
86+
# which rpy_rates_deriv would have resulted in the desired rpy_rates, and return that.
87+
roll_cmd, pitch_cmd, yaw_cmd = attitude
88+
rpy = rot.as_euler("xyz")
89+
roll_rate = SYS_ID_PARAMS["roll_alpha"] * rpy[0] + SYS_ID_PARAMS["roll_beta"] * roll_cmd
90+
pitch_rate = SYS_ID_PARAMS["pitch_alpha"] * rpy[1] + SYS_ID_PARAMS["pitch_beta"] * pitch_cmd
91+
yaw_rate = SYS_ID_PARAMS["yaw_alpha"] * rpy[2] + SYS_ID_PARAMS["yaw_beta"] * yaw_cmd
92+
rpy_rates = jnp.array([roll_rate, pitch_rate, yaw_rate])
93+
rpy_rates_deriv = (rpy_rates - rot.apply(state.rpy_rates, inverse=True)) / dt
94+
return acc, rpy_rates_deriv
95+
96+
7297
def analytical_dynamics(
7398
forces: Array,
7499
torques: Array,
@@ -91,14 +116,25 @@ def analytical_dynamics(
91116
# Update state.
92117
next_pos = pos + vel * dt
93118
next_vel = vel + acc * dt
94-
next_rot = R.from_euler("xyz", R.from_quat(quat).as_euler("xyz") + rpy_rates_local * dt)
119+
next_rot = R.from_euler("xyz", rot.as_euler("xyz") + rpy_rates_local * dt)
95120
next_quat = next_rot.as_quat()
96121
# Convert rpy rates back to global frame
97122
next_rpy_rates_local = rpy_rates_local + rpy_rates_deriv_local * dt
98123
next_rpy_rates = next_rot.apply(next_rpy_rates_local) # Always give rpy rates in world frame
99124
return next_pos, next_quat, next_vel, next_rpy_rates
100125

101126

127+
def analytical_dynamics_dx(
128+
forces: Array, torques: Array, state: SimState, params: SimParams
129+
) -> tuple[Array, Array]:
130+
"""Derivative of the analytical dynamics state."""
131+
rot = R.from_quat(state.quat)
132+
torques_local = rot.apply(torques, inverse=True)
133+
acc = forces / params.mass - jnp.array([0, 0, GRAVITY])
134+
rpy_rates_deriv = rot.apply(params.J_INV @ torques_local)
135+
return acc, rpy_rates_deriv
136+
137+
102138
def rpms2collective_wrench(
103139
rpms: Array, quat: Array, rpy_rates: Array, J: Array
104140
) -> tuple[Array, Array]:

crazyflow/sim/structs.py

-2
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ def default_state(n_worlds: int, n_drones: int, device: Device) -> SimState:
2626
@dataclass
2727
class SimControls:
2828
state: Array # (N, M, 13)
29-
staged_state: Array # (N, M, 13)
3029
attitude: Array # (N, M, 4)
3130
staged_attitude: Array # (N, M, 4)
3231
thrust: Array # (N, M, 4)
@@ -40,7 +39,6 @@ def default_controls(n_worlds: int, n_drones: int, device: Device) -> SimControl
4039
"""Create a default set of controls for the simulation."""
4140
return SimControls(
4241
state=jnp.zeros((n_worlds, n_drones, 13), device=device),
43-
staged_state=jnp.zeros((n_worlds, n_drones, 13), device=device),
4442
attitude=jnp.zeros((n_worlds, n_drones, 4), device=device),
4543
staged_attitude=jnp.zeros((n_worlds, n_drones, 4), device=device),
4644
thrust=jnp.zeros((n_worlds, n_drones, 4), device=device),

0 commit comments

Comments
 (0)