10
10
from gymnasium .envs .mujoco .mujoco_rendering import MujocoRenderer
11
11
from jax import Array , Device
12
12
from jax .scipy .spatial .transform import Rotation as R
13
- from lsy_models .models_numeric import f_first_principles
13
+ from lsy_models .models_numeric import f_first_principles , f_fitted_DI_rpy
14
14
from mujoco .mjx import Data , Model
15
15
16
16
from crazyflow .constants import J_INV , MASS , SIGN_MIX_MATRIX , J
17
- from crazyflow .control .control import Control , attitude2rpm , pwm2rpm , state2attitude , thrust2pwm
17
+ from crazyflow .control .control import Control , attitude2thrust , pwm2rpm , state2attitude , thrust2pwm
18
18
from crazyflow .exception import ConfigError , NotInitializedError
19
19
from crazyflow .sim .integration import Integrator , euler , rk4
20
- from crazyflow .sim .physics import (
21
- Physics ,
22
- collective_force2acceleration ,
23
- collective_torque2ang_vel_deriv ,
24
- rpms2collective_wrench ,
25
- rpms2motor_forces ,
26
- rpms2motor_torques ,
27
- surrogate_identified_collective_wrench ,
28
- )
20
+ from crazyflow .sim .physics import Physics , rpms2motor_forces , rpms2motor_torques
29
21
from crazyflow .sim .structs import SimControls , SimCore , SimData , SimParams , SimState , SimStateDeriv
30
22
from crazyflow .utils import grid_2d , leaf_replace , patch_viewer , pytree_replace , to_device
31
23
@@ -121,7 +113,6 @@ def build_step_fn(self):
121
113
# functions. They act as factories that produce building blocks for the construction of our
122
114
# simulation pipeline.
123
115
ctrl_fn = generate_control_fn (self .control )
124
- wrench_fn = generate_wrench_fn (self .physics )
125
116
disturbance_fn = self .disturbance_fn
126
117
physics_fn = generate_physics_fn (self .physics , self .integrator )
127
118
sync_fn = generate_sync_fn (self .physics )
@@ -130,7 +121,6 @@ def build_step_fn(self):
130
121
def single_step (data : SimData , _ : None ) -> tuple [SimData , None ]:
131
122
data = ctrl_fn (data )
132
123
data = disturbance_fn (data )
133
- data = wrench_fn (data )
134
124
data = physics_fn (data )
135
125
data = data .replace (core = data .core .replace (steps = data .core .steps + 1 ))
136
126
# MuJoCo needs to sync after every physics step, so that the next step control, wrench
@@ -431,19 +421,6 @@ def generate_control_fn(control: Control) -> Callable[[SimData], SimData]:
431
421
raise NotImplementedError (f"Control mode { control } not implemented" )
432
422
433
423
434
- def generate_wrench_fn (physics : Physics ) -> Callable [[SimData ], SimData ]:
435
- """Generate the wrench function for the given physics mode."""
436
- match physics :
437
- case Physics .analytical :
438
- return analytical_wrench
439
- case Physics .sys_id :
440
- return identified_wrench
441
- case Physics .mujoco :
442
- return mujoco_wrench
443
- case _:
444
- raise NotImplementedError (f"Physics mode { physics } not implemented" )
445
-
446
-
447
424
def generate_derivative_fn (physics : Physics ) -> Callable [[SimData ], SimData ]:
448
425
"""Generate the derivative function for the given physics mode."""
449
426
match physics :
@@ -543,12 +520,12 @@ def step_attitude_controller(data: SimData) -> SimData:
543
520
# Commit the staged attitude controls
544
521
staged_attitude = controls .staged_attitude
545
522
controls = leaf_replace (controls , mask , attitude_steps = steps , attitude = staged_attitude )
546
- # Compute the new rpm values from the committed attitude controls
523
+ # Compute the new thrust values from the committed attitude controls
547
524
quat , attitude = data .states .quat , controls .attitude
548
525
dt = 1 / controls .attitude_freq
549
- rpms , rpy_err_i = attitude2rpm (attitude , quat , controls .last_rpy , controls .rpy_err_i , dt )
526
+ thrust , rpy_err_i = attitude2thrust (attitude , quat , controls .last_rpy , controls .rpy_err_i , dt )
550
527
rpy = R .from_quat (quat ).as_euler ("xyz" )
551
- controls = leaf_replace (controls , mask , rpms = rpms , rpy_err_i = rpy_err_i , last_rpy = rpy )
528
+ controls = leaf_replace (controls , mask , thrust = thrust , rpy_err_i = rpy_err_i , last_rpy = rpy )
552
529
return data .replace (controls = controls )
553
530
554
531
@@ -557,20 +534,14 @@ def step_thrust_controller(data: SimData) -> SimData:
557
534
controls = data .controls
558
535
steps = data .core .steps
559
536
mask = controllable (steps , data .core .freq , controls .thrust_steps , controls .thrust_freq )
560
- rpms = pwm2rpm (thrust2pwm (controls .thrust ))
561
- controls = leaf_replace (controls , mask , thrust_steps = steps , rpms = rpms )
537
+ raise NotImplementedError ("Thrust controller currently not implemented. Missing staging." )
538
+ # TODO: Introduce thrust staging
539
+ controls = leaf_replace (controls , mask , thrust_steps = steps , thrust = controls .thrust )
562
540
return data .replace (controls = controls )
563
541
564
542
565
- def analytical_wrench (data : SimData ) -> SimData :
566
- """Compute the wrench from the analytical dynamics model."""
567
- states , controls , params = data .states , data .controls , data .params
568
- force , torque = rpms2collective_wrench (controls .rpms , states .quat , states .ang_vel , params .J )
569
- return data .replace (states = data .states .replace (force = force , torque = torque ))
570
-
571
-
572
543
def analytical_derivative (data : SimData ) -> SimData :
573
- """Compute the derivative of the states ."""
544
+ """Compute the state derivative from first principles ."""
574
545
dpos , _ , dvel , dang_vel , df_motor = f_first_principles (
575
546
data .states .pos ,
576
547
data .states .quat ,
@@ -588,17 +559,23 @@ def analytical_derivative(data: SimData) -> SimData:
588
559
return data .replace (states_deriv = states_deriv )
589
560
590
561
591
- def identified_wrench (data : SimData ) -> SimData :
592
- """Compute the wrench from the identified dynamics model."""
593
- states , controls = data .states , data .controls
594
- mass , J = data .params .mass , data .params .J
595
- force , torque = surrogate_identified_collective_wrench (
596
- controls .attitude , states .quat , states .ang_vel , mass , J , 1 / data .core .freq
562
+ def identified_derivative (data : SimData ) -> SimData :
563
+ """Compute the state derivative from the identified dynamics model."""
564
+ dpos , _ , dvel , dang_vel , df_motor = f_fitted_DI_rpy (
565
+ data .states .pos ,
566
+ data .states .quat ,
567
+ data .states .vel ,
568
+ data .states .ang_vel ,
569
+ data .controls .thrust ,
570
+ data .params ,
571
+ None , # Fitted model does not have motor dynamics, we assume control can be matched
572
+ data .states .force ,
573
+ data .states .torque ,
597
574
)
598
- return data . replace ( states = data .states .replace (force = force , torque = torque ))
599
-
600
-
601
- identified_derivative = analytical_derivative # We can use the same derivative function for both
575
+ states_deriv = data .states_deriv .replace (
576
+ dpos = dpos , drot = dang_vel , dvel = dvel , dang_vel = dang_vel , dmotor_forces = df_motor
577
+ )
578
+ return data . replace ( states_deriv = states_deriv )
602
579
603
580
604
581
def mujoco_wrench (data : SimData ) -> SimData :
0 commit comments