Skip to content

Commit 6670909

Browse files
committed
Add build API to (re)build the simulation after changing functions or MuJoCo specs
1 parent e6dacec commit 6670909

File tree

4 files changed

+73
-29
lines changed

4 files changed

+73
-29
lines changed

crazyflow/sim/sim.py

+70-25
Original file line numberDiff line numberDiff line change
@@ -66,37 +66,22 @@ def __init__(
6666

6767
# Initialize MuJoCo world and data
6868
self._xml_path = xml_path or self.default_path
69-
self.spec, self.mj_model, self.mj_data, self.mjx_model, mjx_data = self.setup_mj()
69+
self.spec = self.init_mjx_spec()
70+
self.mj_model, self.mj_data, self.mjx_model, mjx_data = self.init_mjx_model(self.spec)
7071
self.viewer: MujocoRenderer | None = None
7172

72-
# Allocate internal states and controls
73-
drone_ids = [self.mj_model.body(f"drone:{i}").id for i in range(n_drones)]
74-
self.data = SimData(
75-
states=SimState.create(n_worlds, n_drones, self.device),
76-
states_deriv=SimStateDeriv.create(n_worlds, n_drones, self.device),
77-
controls=SimControls.create(
78-
n_worlds, n_drones, state_freq, attitude_freq, thrust_freq, self.device
79-
),
80-
params=SimParams.create(n_worlds, n_drones, MASS, J, J_INV, self.device),
81-
core=SimCore.create(freq, n_worlds, n_drones, drone_ids, rng_key, self.device),
82-
mjx_data=mjx_data,
83-
mjx_model=None,
73+
self.data, self.default_data = self.init_data(
74+
state_freq, attitude_freq, thrust_freq, rng_key, mjx_data
8475
)
85-
if self.n_drones > 1: # If multiple drones, arrange them in a grid
86-
grid = grid_2d(self.n_drones)
87-
states = self.data.states.replace(pos=self.data.states.pos.at[..., :2].set(grid))
88-
self.data: SimData = self.data.replace(states=states)
89-
90-
self.data = self.sync_sim2mjx(self.data, self.mjx_model)
91-
self.default_data = self.data.replace() # TODO: Only save the data of one world
9276

9377
# Default functions for the simulation pipeline
9478
self.disturbance_fn: Callable[[SimData], SimData] | None = None
9579

9680
# Build the simulation pipeline and overwrite the default _step implementation with it
97-
self.build()
81+
self.init_step_fn()
9882

99-
def setup_mj(self) -> tuple[Any, Any, Any, Model, Data]:
83+
def init_mjx_spec(self) -> mujoco.MjSpec:
84+
"""Build the MuJoCo model specification for the simulation."""
10085
assert self._xml_path.exists(), f"Model file {self._xml_path} does not exist"
10186
spec = mujoco.MjSpec.from_file(str(self._xml_path))
10287
spec.option.timestep = 1 / self.freq
@@ -110,7 +95,10 @@ def setup_mj(self) -> tuple[Any, Any, Any, Model, Data]:
11095
for i in range(self.n_drones):
11196
drone = frame.attach_body(drone_spec.find_body("drone"), "", f":{i}")
11297
drone.add_freejoint()
113-
# Compile and create data structures
98+
return spec
99+
100+
def init_mjx_model(self, spec: mujoco.MjSpec) -> tuple[Any, Any, Model, Data]:
101+
"""Build the MuJoCo model and data structures for the simulation."""
114102
mj_model = spec.compile()
115103
mj_data = mujoco.MjData(mj_model)
116104
mjx_model = mjx.put_model(mj_model, device=self.device)
@@ -120,9 +108,9 @@ def setup_mj(self) -> tuple[Any, Any, Any, Model, Data]:
120108
# https://github.com/jax-ml/jax/issues/4274#issuecomment-692406759
121109
# Tracking issue: https://github.com/google-deepmind/mujoco/issues/2306
122110
mjx_data = mjx_data.replace(time=jnp.float32(mjx_data.time))
123-
return spec, mj_model, mj_data, mjx_model, mjx_data
111+
return mj_model, mj_data, mjx_model, mjx_data
124112

125-
def build(self):
113+
def init_step_fn(self):
126114
"""Setup the chain of functions that are called in Sim.step().
127115
128116
We know all the functions that are called in succession since the simulation is configured
@@ -179,6 +167,62 @@ def step(data: SimData, n_steps: int = 1) -> SimData:
179167

180168
self._step = step
181169

170+
def init_data(
171+
self, state_freq: int, attitude_freq: int, thrust_freq: int, rng_key: Array, mjx_data: Data
172+
) -> tuple[SimData, SimData]:
173+
"""Initialize the simulation data."""
174+
drone_ids = [self.mj_model.body(f"drone:{i}").id for i in range(self.n_drones)]
175+
N, D = self.n_worlds, self.n_drones
176+
data = SimData(
177+
states=SimState.create(N, D, self.device),
178+
states_deriv=SimStateDeriv.create(N, D, self.device),
179+
controls=SimControls.create(N, D, state_freq, attitude_freq, thrust_freq, self.device),
180+
params=SimParams.create(N, D, MASS, J, J_INV, self.device),
181+
core=SimCore.create(self.freq, N, D, drone_ids, rng_key, self.device),
182+
mjx_data=mjx_data,
183+
mjx_model=None,
184+
)
185+
if D > 1: # If multiple drones, arrange them in a grid
186+
grid = grid_2d(D)
187+
states = data.states.replace(pos=data.states.pos.at[..., :2].set(grid))
188+
data = data.replace(states=states)
189+
data = self.sync_sim2mjx(data, self.mjx_model)
190+
191+
return data, data.replace() # TODO: Only save the data of one world
192+
193+
def build(self, mjx: bool = True, data: bool = True, step: bool = True):
194+
"""Build the simulation pipeline.
195+
196+
This method is used to (re)build the simulation pipeline after changing the MuJoCo
197+
model specification or any of the default functions that are used in the compiled step
198+
function.
199+
200+
Warning:
201+
Depending on what you build, you reset the simulation state. For example, rebuilding the
202+
simulation data will reset the drone states.
203+
204+
Args:
205+
mjx: Flag to (re)build the MuJoCo model and data structures.
206+
data: Flag to (re)build the simulation data.
207+
step: Flag to (re)build the simulation step function.
208+
"""
209+
# TODO: Write tests for all options
210+
if mjx:
211+
if self.viewer is not None:
212+
self.viewer.close()
213+
self.viewer = None
214+
self.mj_model, self.mj_data, self.mjx_model, mjx_data = self.init_mjx_model(self.spec)
215+
if data:
216+
self.data, self.default_data = self.init_data(
217+
self.data.controls.state_freq,
218+
self.data.controls.attitude_freq,
219+
self.data.controls.thrust_freq,
220+
self.data.core.rng_key,
221+
self.data.mjx_data if not mjx else mjx_data,
222+
)
223+
if step:
224+
self.init_step_fn()
225+
182226
def reset(self, mask: Array | None = None):
183227
"""Reset the simulation to the initial state.
184228
@@ -299,6 +343,7 @@ def sync_sim2mjx(data: SimData, mjx_model: Model | None = None) -> SimData:
299343
qvel = rearrange(jnp.concat([vel, local_ang_vel], axis=-1), "w d qvel -> w (d qvel)")
300344
mjx_data = data.mjx_data
301345
mjx_model = data.mjx_model if mjx_model is None else mjx_model
346+
assert mjx_model is not None, "MuJoCo model is not initialized"
302347
mjx_data = mjx_data.replace(qpos=qpos, qvel=qvel)
303348
mjx_data = mjx_kinematics(mjx_model, mjx_data)
304349
mjx_data = mjx_collision(mjx_model, mjx_data)

examples/disturbance.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def main(plot: bool = False):
3636

3737
# Second run
3838
sim.disturbance_fn = disturbance_fn
39-
sim.build()
39+
sim.build(mjx=False, data=False, step=True)
4040
pos_disturbed, rpy_disturbed = [], []
4141
sim.reset()
4242
for _ in range(3 * sim.control_freq):

tests/integration/test_disturbance.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def test_disturbance(physics: Physics):
2929

3030
sim.reset()
3131
sim.disturbance_fn = disturbance_fn
32-
sim.build()
32+
sim.build(mjx=False, data=False, step=True)
3333
for _ in range(sim.control_freq):
3434
sim.state_control(control)
3535
sim.step(sim.freq // sim.control_freq)

tests/unit/test_sim.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -305,8 +305,7 @@ def test_control_frequency(physics: Physics):
305305
sim_500.step()
306306

307307
sim_1000.state_control(cmd)
308-
sim_1000.step()
309-
sim_1000.step()
308+
sim_1000.step(2)
310309

311310
# Check that the controls are the same
312311
assert np.all(sim_500.data.controls.rpms == sim_1000.data.controls.rpms)

0 commit comments

Comments
 (0)