14
14
from crazyflow .control .controller import J_INV , Control , Controller , J
15
15
from crazyflow .exception import ConfigError , NotInitializedError
16
16
from crazyflow .sim .fused import (
17
+ attitude2rpm ,
17
18
fused_analytical_dynamics ,
18
19
fused_identified_dynamics ,
19
- fused_masked_attitude2rpm ,
20
- fused_masked_state2attitude ,
21
20
fused_rpms2collective_wrench ,
21
+ state2attitude ,
22
22
)
23
23
from crazyflow .sim .integration import Integrator
24
24
from crazyflow .sim .physics import Physics
@@ -125,12 +125,12 @@ def setup_pipeline(self) -> Callable[[int, SimData], SimData]:
125
125
# drones (axis 1) in parallel.
126
126
ctrl_fn = jax .vmap (jax .vmap (self ._control_fn ()))
127
127
physics_fn = jax .vmap (jax .vmap (self ._physics_fn ()))
128
- integrator_fn = jax .vmap (jax .vmap (self ._integrator_fn ()))
128
+ # integrator_fn = jax.vmap(jax.vmap(self._integrator_fn()))
129
129
130
130
def _step (sim_data : SimData ) -> SimData :
131
131
sim_data = ctrl_fn (sim_data )
132
132
sim_data = physics_fn (sim_data )
133
- sim_data = integrator_fn (sim_data )
133
+ # sim_data = integrator_fn(sim_data)
134
134
return sim_data
135
135
136
136
# ``scan`` can be lowered to a single WhileOp, reducing compilation times while still fusing
@@ -233,7 +233,7 @@ def contacts(self, body: str | None = None) -> Array:
233
233
def _control_fn (self ) -> Callable [[SimData ], SimData ]:
234
234
match self .control :
235
235
case Control .state :
236
- return step_state_controller
236
+ return lambda data : step_attitude_controller ( step_state_controller ( data ))
237
237
case Control .attitude :
238
238
return step_attitude_controller
239
239
case _:
@@ -242,9 +242,9 @@ def _control_fn(self) -> Callable[[SimData], SimData]:
242
242
def _physics_fn (self ) -> Callable [[SimData ], SimData ]:
243
243
match self .physics :
244
244
case Physics .analytical :
245
- return self . _step_analytical
245
+ return analytical_dynamics
246
246
case Physics .sys_id :
247
- return self . _step_sys_id
247
+ return identified_dynamics
248
248
case _:
249
249
raise NotImplementedError (f"Physics mode { self .physics } not implemented" )
250
250
@@ -260,8 +260,7 @@ def _step_sys_id(self):
260
260
# Optional optimization: check if mask.any() before updating the controls. This breaks jax's
261
261
# gradient tracing, so we omit it for now.
262
262
if self .control == Control .state :
263
- self .controls = self ._masked_state_controls_update (mask , self .controls )
264
- self .controls = fused_masked_state2attitude (mask , self .states , self .controls , self .dt )
263
+ self .controls = state2attitude (mask , self .states , self .controls , self .dt )
265
264
self .controls = self ._masked_attitude_controls_update (mask , self .controls )
266
265
self .last_ctrl_steps = self ._masked_controls_step_update (
267
266
mask , self .steps , self .last_ctrl_steps
@@ -288,10 +287,9 @@ def _step_analytical(self):
288
287
def _step_emulate_firmware (self ) -> SimControls :
289
288
mask = self .controllable
290
289
if self .control == Control .state :
291
- self .controls = self ._masked_state_controls_update (mask , self .controls )
292
- self .controls = fused_masked_state2attitude (mask , self .states , self .controls , self .dt )
290
+ self .controls = state2attitude (mask , self .states , self .controls , self .dt )
293
291
self .controls = self ._masked_attitude_controls_update (mask , self .controls )
294
- return fused_masked_attitude2rpm (mask , self .states , self .controls , self .dt )
292
+ return attitude2rpm (mask , self .states , self .controls , self .dt )
295
293
296
294
@staticmethod
297
295
def _sync_mjx (states : SimState , mjx_data : Data , mjx_model : Model ) -> Data :
@@ -323,64 +321,52 @@ def _sync_mjx_full(states: SimState, mjx_data: Data, mjx_model: Model) -> Data:
323
321
@staticmethod
324
322
@jax .jit
325
323
def _masked_states_reset (mask : Array , states : SimState , defaults : SimState ) -> SimState :
326
- mask_3d = mask [:, None , None ]
327
- states = states .replace (pos = jnp .where (mask_3d , defaults .pos , states .pos ))
328
- states = states .replace (quat = jnp .where (mask_3d , defaults .quat , states .quat ))
329
- states = states .replace (vel = jnp .where (mask_3d , defaults .vel , states .vel ))
330
- states = states .replace (ang_vel = jnp .where (mask_3d , defaults .ang_vel , states .ang_vel ))
331
- states = states .replace (rpy_rates = jnp .where (mask_3d , defaults .rpy_rates , states .rpy_rates ))
332
- return states
324
+ mask = mask .reshape (- 1 , 1 , 1 )
325
+ return jax .tree .map (lambda x , y : jnp .where (mask , y , x ), states , defaults )
333
326
334
327
@staticmethod
335
328
@jax .jit
336
329
def _masked_controls_reset (
337
330
mask : Array , controls : SimControls , defaults : SimControls
338
331
) -> SimControls :
339
- mask = mask [:, None , None ]
340
- controls = controls .replace (
341
- state = jnp .where (mask , defaults .state , controls .state ),
342
- attitude = jnp .where (mask , defaults .attitude , controls .attitude ),
343
- thrust = jnp .where (mask , defaults .thrust , controls .thrust ),
344
- rpms = jnp .where (mask , defaults .rpms , controls .rpms ),
345
- rpy_err_i = jnp .where (mask , defaults .rpy_err_i , controls .rpy_err_i ),
346
- pos_err_i = jnp .where (mask , defaults .pos_err_i , controls .pos_err_i ),
347
- last_rpy = jnp .where (mask , defaults .last_rpy , controls .last_rpy ),
348
- staged_attitude = jnp .where (mask , defaults .staged_attitude , controls .staged_attitude ),
349
- staged_state = jnp .where (mask , defaults .staged_state , controls .staged_state ),
350
- )
351
- return controls
332
+ mask = mask .reshape (- 1 , 1 , 1 )
333
+ return jax .tree .map (lambda x , y : jnp .where (mask , y , x ), controls , defaults )
352
334
353
335
@staticmethod
354
336
@jax .jit
355
337
def _masked_params_reset (mask : Array , params : SimParams , defaults : SimParams ) -> SimParams :
356
- params = params .replace (mass = jnp .where (mask [:, None , None ], defaults .mass , params .mass ))
357
- mask_4d = mask [:, None , None , None ]
358
- params = params .replace (J = jnp .where (mask_4d , defaults .J , params .J ))
359
- params = params .replace (J_INV = jnp .where (mask_4d , defaults .J_INV , params .J_INV ))
338
+ mask = mask .reshape (- 1 , 1 , 1 )
339
+ params = params .replace (mass = jnp .where (mask , defaults .mass , params .mass ))
340
+ # J and J_INV are matrices -> we need (W, D, N, N) = 4 dims
341
+ mask = mask .reshape (- 1 , 1 , 1 , 1 )
342
+ params = params .replace (J = jnp .where (mask , defaults .J , params .J ))
343
+ params = params .replace (J_INV = jnp .where (mask , defaults .J_INV , params .J_INV ))
360
344
return params
361
345
362
346
@staticmethod
363
347
@partial (jax .jit , static_argnames = "device" )
364
348
def _attitude_control (cmd : Array , controls : SimControls , device : str ) -> SimControls :
349
+ """Stage the desired attitude for all drones in all worlds.
350
+
351
+ We need to stage the attitude commands because the sys_id physics mode operates directly on
352
+ the attitude command. If we were to directly update the controls, this would effectively
353
+ bypass the control frequency and run the attitude controller at the physics update rate. By
354
+ staging the commands, we ensure that the physics module sees the old commands until the
355
+ controller updates at its correct frequency.
356
+ """
365
357
return controls .replace (staged_attitude = jnp .array (cmd , device = device ))
366
358
367
359
@staticmethod
368
360
@partial (jax .jit , static_argnames = "device" )
369
361
def _state_control (cmd : Array , controls : SimControls , device : str ) -> SimControls :
370
- return controls .replace (staged_state = jnp .array (cmd , device = device ))
362
+ return controls .replace (state = jnp .array (cmd , device = device ))
371
363
372
364
@staticmethod
373
365
@jax .jit
374
366
def _masked_attitude_controls_update (mask : Array , controls : SimControls ) -> SimControls :
375
367
cmd , staged_cmd = controls .attitude , controls .staged_attitude
376
368
return controls .replace (attitude = jnp .where (mask [:, None , None ], staged_cmd , cmd ))
377
369
378
- @staticmethod
379
- @jax .jit
380
- def _masked_state_controls_update (mask : Array , controls : SimControls ) -> SimControls :
381
- cmd , staged_cmd = controls .state , controls .staged_state
382
- return controls .replace (state = jnp .where (mask [:, None , None ], staged_cmd , cmd ))
383
-
384
370
@staticmethod
385
371
@jax .jit
386
372
def _masked_controls_step_update (mask : Array , steps : Array , last_ctrl_steps : Array ) -> Array :
@@ -403,24 +389,30 @@ def contacts(geom_start: int, geom_count: int, data: Data) -> Array:
403
389
404
390
405
391
def step_state_controller (data : SimData ) -> SimData :
392
+ """Compute the updated controls for the state controller."""
406
393
controls = data .controls
407
- mask = controllable (data .steps , controls .steps , data .freq , controls .state_freq )
408
- controls = commit_state_controls ( mask , controls )
394
+ mask = controllable (data .steps , controls .state_steps , data .freq , controls .state_freq )
395
+ controls = controls . replace ( state_steps = jnp . where ( mask , data . steps , controls . state_steps ) )
409
396
controls = state2attitude (mask , data .states , controls , 1 / data .freq )
410
397
return data .replace (controls = controls )
411
398
412
399
413
- def controllable (step : Array , ctrl_step : Array , ctrl_freq : int , freq : int ) -> Array :
414
- return (step - ctrl_step ) >= (freq / ctrl_freq )
400
+ def step_attitude_controller (data : SimData ) -> SimData :
401
+ """Compute the updated controls for the attitude controller."""
402
+ controls = data .controls
403
+ mask = controllable (data .steps , controls .attitude_steps , data .freq , controls .attitude_freq )
404
+ controls = commit_attitude_controls (mask , controls )
405
+ controls = attitude2rpm (mask , data .states , controls , 1 / data .freq )
406
+ return data .replace (controls = controls )
415
407
416
408
417
- def commit_state_controls (mask : Array , controls : SimControls ) -> SimControls :
418
- cmd , staged_cmd = controls .state , controls .staged_state
419
- return controls .replace (state = jnp .where (mask [:, None , None ], staged_cmd , cmd ))
409
+ def controllable (step : Array , ctrl_step : Array , ctrl_freq : int , freq : int ) -> Array :
410
+ return (step - ctrl_step ) >= (freq / ctrl_freq )
420
411
421
412
422
- def step_attitude_controller (data : SimData ) -> SimData :
423
- pass
413
+ def commit_attitude_controls (mask : Array , controls : SimControls ) -> SimControls :
414
+ cmd , staged_cmd = controls .attitude , controls .staged_attitude
415
+ return controls .replace (attitude = jnp .where (mask .reshape (- 1 , 1 , 1 ), staged_cmd , cmd ))
424
416
425
417
426
418
mjx_kinematics = jax .vmap (mjx .kinematics , in_axes = (None , 0 ))
0 commit comments