@@ -66,37 +66,22 @@ def __init__(
66
66
67
67
# Initialize MuJoCo world and data
68
68
self ._xml_path = xml_path or self .default_path
69
- self .spec , self .mj_model , self .mj_data , self .mjx_model , mjx_data = self .setup_mj ()
69
+ self .spec = self .init_mjx_spec ()
70
+ self .mj_model , self .mj_data , self .mjx_model , mjx_data = self .init_mjx_model (self .spec )
70
71
self .viewer : MujocoRenderer | None = None
71
72
72
- # Allocate internal states and controls
73
- drone_ids = [self .mj_model .body (f"drone:{ i } " ).id for i in range (n_drones )]
74
- self .data = SimData (
75
- states = SimState .create (n_worlds , n_drones , self .device ),
76
- states_deriv = SimStateDeriv .create (n_worlds , n_drones , self .device ),
77
- controls = SimControls .create (
78
- n_worlds , n_drones , state_freq , attitude_freq , thrust_freq , self .device
79
- ),
80
- params = SimParams .create (n_worlds , n_drones , MASS , J , J_INV , self .device ),
81
- core = SimCore .create (freq , n_worlds , n_drones , drone_ids , rng_key , self .device ),
82
- mjx_data = mjx_data ,
83
- mjx_model = None ,
73
+ self .data , self .default_data = self .init_data (
74
+ state_freq , attitude_freq , thrust_freq , rng_key , mjx_data
84
75
)
85
- if self .n_drones > 1 : # If multiple drones, arrange them in a grid
86
- grid = grid_2d (self .n_drones )
87
- states = self .data .states .replace (pos = self .data .states .pos .at [..., :2 ].set (grid ))
88
- self .data : SimData = self .data .replace (states = states )
89
-
90
- self .data = self .sync_sim2mjx (self .data , self .mjx_model )
91
- self .default_data = self .data .replace () # TODO: Only save the data of one world
92
76
93
77
# Default functions for the simulation pipeline
94
78
self .disturbance_fn : Callable [[SimData ], SimData ] | None = None
95
79
96
80
# Build the simulation pipeline and overwrite the default _step implementation with it
97
- self .build ()
81
+ self .init_step_fn ()
98
82
99
- def setup_mj (self ) -> tuple [Any , Any , Any , Model , Data ]:
83
+ def init_mjx_spec (self ) -> mujoco .MjSpec :
84
+ """Build the MuJoCo model specification for the simulation."""
100
85
assert self ._xml_path .exists (), f"Model file { self ._xml_path } does not exist"
101
86
spec = mujoco .MjSpec .from_file (str (self ._xml_path ))
102
87
spec .option .timestep = 1 / self .freq
@@ -110,7 +95,10 @@ def setup_mj(self) -> tuple[Any, Any, Any, Model, Data]:
110
95
for i in range (self .n_drones ):
111
96
drone = frame .attach_body (drone_spec .find_body ("drone" ), "" , f":{ i } " )
112
97
drone .add_freejoint ()
113
- # Compile and create data structures
98
+ return spec
99
+
100
+ def init_mjx_model (self , spec : mujoco .MjSpec ) -> tuple [Any , Any , Model , Data ]:
101
+ """Build the MuJoCo model and data structures for the simulation."""
114
102
mj_model = spec .compile ()
115
103
mj_data = mujoco .MjData (mj_model )
116
104
mjx_model = mjx .put_model (mj_model , device = self .device )
@@ -120,9 +108,9 @@ def setup_mj(self) -> tuple[Any, Any, Any, Model, Data]:
120
108
# https://github.com/jax-ml/jax/issues/4274#issuecomment-692406759
121
109
# Tracking issue: https://github.com/google-deepmind/mujoco/issues/2306
122
110
mjx_data = mjx_data .replace (time = jnp .float32 (mjx_data .time ))
123
- return spec , mj_model , mj_data , mjx_model , mjx_data
111
+ return mj_model , mj_data , mjx_model , mjx_data
124
112
125
- def build (self ):
113
+ def init_step_fn (self ):
126
114
"""Setup the chain of functions that are called in Sim.step().
127
115
128
116
We know all the functions that are called in succession since the simulation is configured
@@ -179,6 +167,62 @@ def step(data: SimData, n_steps: int = 1) -> SimData:
179
167
180
168
self ._step = step
181
169
170
+ def init_data (
171
+ self , state_freq : int , attitude_freq : int , thrust_freq : int , rng_key : Array , mjx_data : Data
172
+ ) -> tuple [SimData , SimData ]:
173
+ """Initialize the simulation data."""
174
+ drone_ids = [self .mj_model .body (f"drone:{ i } " ).id for i in range (self .n_drones )]
175
+ N , D = self .n_worlds , self .n_drones
176
+ data = SimData (
177
+ states = SimState .create (N , D , self .device ),
178
+ states_deriv = SimStateDeriv .create (N , D , self .device ),
179
+ controls = SimControls .create (N , D , state_freq , attitude_freq , thrust_freq , self .device ),
180
+ params = SimParams .create (N , D , MASS , J , J_INV , self .device ),
181
+ core = SimCore .create (self .freq , N , D , drone_ids , rng_key , self .device ),
182
+ mjx_data = mjx_data ,
183
+ mjx_model = None ,
184
+ )
185
+ if D > 1 : # If multiple drones, arrange them in a grid
186
+ grid = grid_2d (D )
187
+ states = data .states .replace (pos = data .states .pos .at [..., :2 ].set (grid ))
188
+ data = data .replace (states = states )
189
+ data = self .sync_sim2mjx (data , self .mjx_model )
190
+
191
+ return data , data .replace () # TODO: Only save the data of one world
192
+
193
+ def build (self , mjx : bool = True , data : bool = True , step : bool = True ):
194
+ """Build the simulation pipeline.
195
+
196
+ This method is used to (re)build the simulation pipeline after changing the MuJoCo
197
+ model specification or any of the default functions that are used in the compiled step
198
+ function.
199
+
200
+ Warning:
201
+ Depending on what you build, you reset the simulation state. For example, rebuilding the
202
+ simulation data will reset the drone states.
203
+
204
+ Args:
205
+ mjx: Flag to (re)build the MuJoCo model and data structures.
206
+ data: Flag to (re)build the simulation data.
207
+ step: Flag to (re)build the simulation step function.
208
+ """
209
+ # TODO: Write tests for all options
210
+ if mjx :
211
+ if self .viewer is not None :
212
+ self .viewer .close ()
213
+ self .viewer = None
214
+ self .mj_model , self .mj_data , self .mjx_model , mjx_data = self .init_mjx_model (self .spec )
215
+ if data :
216
+ self .data , self .default_data = self .init_data (
217
+ self .data .controls .state_freq ,
218
+ self .data .controls .attitude_freq ,
219
+ self .data .controls .thrust_freq ,
220
+ self .data .core .rng_key ,
221
+ self .data .mjx_data if not mjx else mjx_data ,
222
+ )
223
+ if step :
224
+ self .init_step_fn ()
225
+
182
226
def reset (self , mask : Array | None = None ):
183
227
"""Reset the simulation to the initial state.
184
228
@@ -299,6 +343,7 @@ def sync_sim2mjx(data: SimData, mjx_model: Model | None = None) -> SimData:
299
343
qvel = rearrange (jnp .concat ([vel , local_ang_vel ], axis = - 1 ), "w d qvel -> w (d qvel)" )
300
344
mjx_data = data .mjx_data
301
345
mjx_model = data .mjx_model if mjx_model is None else mjx_model
346
+ assert mjx_model is not None , "MuJoCo model is not initialized"
302
347
mjx_data = mjx_data .replace (qpos = qpos , qvel = qvel )
303
348
mjx_data = mjx_kinematics (mjx_model , mjx_data )
304
349
mjx_data = mjx_collision (mjx_model , mjx_data )
0 commit comments