Skip to content

Commit f1b2976

Browse files
committed
Refactor jit'ed functions
1 parent b451fd1 commit f1b2976

File tree

1 file changed

+66
-108
lines changed

1 file changed

+66
-108
lines changed

lsy_drone_racing/envs/multi_drone_race.py

+66-108
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
randomize_gate_rpy_fn,
3131
randomize_obstacle_pos_fn,
3232
)
33-
from lsy_drone_racing.utils import check_gate_pass
3433

3534
if TYPE_CHECKING:
3635
from crazyflow.sim.structs import SimData
@@ -227,7 +226,7 @@ def step(
227226
self.sim.step(self.sim.freq // self.freq)
228227
# TODO: Clean up the accelerated functions
229228
self.disabled_drones = np.array(
230-
self.update_active_drones_acc(
229+
self._disabled_drones(
231230
self.sim.data.states.pos[0],
232231
self.sim.data.states.quat[0],
233232
self.pos_bounds.low,
@@ -242,7 +241,7 @@ def step(
242241
)
243242
self.sim.data = self.warp_disabled_drones(self.sim.data, self.disabled_drones)
244243
# TODO: Clean up the accelerated functions
245-
passed = self.gate_passed_accelerated(
244+
passed = self._gate_passed(
246245
self.target_gate,
247246
self.gates["mocap_ids"],
248247
self.sim.data.mjx_data.mocap_pos[0],
@@ -263,21 +262,11 @@ def render(self):
263262
def obs(self) -> dict[str, NDArray[np.floating]]:
264263
"""Return the observation of the environment."""
265264
# TODO: Accelerate this function
266-
obs = {
267-
"pos": np.array(self.sim.data.states.pos[0], dtype=np.float32),
268-
"rpy": R.from_quat(self.sim.data.states.quat[0]).as_euler("xyz").astype(np.float32),
269-
"vel": np.array(self.sim.data.states.vel[0], dtype=np.float32),
270-
"ang_vel": np.array(self.sim.data.states.rpy_rates[0], dtype=np.float32),
271-
}
272-
obs["target_gate"] = self.target_gate
273265
# Add the gate and obstacle poses to the info. If gates or obstacles are in sensor range,
274266
# use the actual pose, otherwise use the nominal pose.
275-
drone_pos = self.sim.data.states.pos[0]
276-
# Performance optimization: Get a continuous slice instead of using a list of indices which
277-
# copies the data. Assumes that the mocap ids are consecutive.
278-
gates_visited, gates_pos, gates_rpy = self.obs_acc_gates(
267+
gates_visited, gates_pos, gates_rpy = self._obs_gates(
279268
self.gates_visited,
280-
drone_pos,
269+
self.sim.data.states.pos[0],
281270
self.sim.data.mjx_data.mocap_pos[0],
282271
self.sim.data.mjx_data.mocap_quat[0],
283272
self.gates["mocap_ids"],
@@ -286,62 +275,66 @@ def obs(self) -> dict[str, NDArray[np.floating]]:
286275
self.gates["nominal_rpy"],
287276
)
288277
self.gates_visited = np.asarray(gates_visited, dtype=bool)
289-
obs["gates_pos"] = np.asarray(gates_pos, dtype=np.float32)
290-
obs["gates_rpy"] = np.asarray(gates_rpy, dtype=np.float32)
291-
obs["gates_visited"] = self.gates_visited
292-
293-
obstacles_visited, obstacles_pos = self.obs_acc_obstacles(
278+
obstacles_visited, obstacles_pos = self._obs_obstacles(
294279
self.obstacles_visited,
295-
drone_pos,
280+
self.sim.data.states.pos[0],
296281
self.sim.data.mjx_data.mocap_pos[0],
297282
self.obstacles["mocap_ids"],
298283
self.sensor_range,
299284
self.obstacles["nominal_pos"],
300285
)
301286
self.obstacles_visited = np.asarray(obstacles_visited, dtype=bool)
302-
obs["obstacles_pos"] = np.asarray(obstacles_pos, dtype=np.float32)
303-
obs["obstacles_visited"] = self.obstacles_visited
304287
# TODO: Decide on observation disturbances
288+
obs = {
289+
"pos": np.array(self.sim.data.states.pos[0], dtype=np.float32),
290+
"rpy": R.from_quat(self.sim.data.states.quat[0]).as_euler("xyz").astype(np.float32),
291+
"vel": np.array(self.sim.data.states.vel[0], dtype=np.float32),
292+
"ang_vel": np.array(self.sim.data.states.rpy_rates[0], dtype=np.float32),
293+
"target_gate": self.target_gate,
294+
"gates_pos": np.asarray(gates_pos, dtype=np.float32),
295+
"gates_rpy": np.asarray(gates_rpy, dtype=np.float32),
296+
"gates_visited": self.gates_visited,
297+
"obstacles_pos": np.asarray(obstacles_pos, dtype=np.float32),
298+
"obstacles_visited": self.obstacles_visited,
299+
}
305300
return obs
306301

307302
@staticmethod
308303
@jax.jit
309-
def obs_acc_gates(
310-
gates_visited,
311-
drone_pos,
312-
mocap_pos,
313-
mocap_quat,
314-
mocap_ids,
315-
sensor_range,
316-
nominal_pos,
317-
nominal_rpy,
318-
):
319-
# TODO: Clean up the accelerated functions
320-
gates_pos = mocap_pos[mocap_ids]
321-
gates_quat = mocap_quat[mocap_ids][..., [1, 2, 3, 0]]
322-
gates_rpy = jax.scipy.spatial.transform.Rotation.from_quat(gates_quat).as_euler("xyz")
323-
dpos = drone_pos[..., None, :2] - gates_pos[:, :2]
304+
def _obs_gates(
305+
gates_visited: NDArray,
306+
drone_pos: Array,
307+
mocap_pos: Array,
308+
mocap_quat: Array,
309+
mocap_ids: NDArray,
310+
sensor_range: float,
311+
nominal_pos: NDArray,
312+
nominal_rpy: NDArray,
313+
) -> tuple[Array, Array, Array]:
314+
"""Get the nominal or real gate positions and orientations depending on the sensor range."""
315+
real_quat = mocap_quat[mocap_ids][..., [1, 2, 3, 0]]
316+
real_rpy = jax.scipy.spatial.transform.Rotation.from_quat(real_quat).as_euler("xyz")
317+
dpos = drone_pos[..., None, :2] - mocap_pos[mocap_ids, :2]
324318
in_range = jp.linalg.norm(dpos, axis=-1) < sensor_range
325319
gates_visited = jp.logical_or(gates_visited, in_range)
326-
327-
mask = gates_visited[..., None]
328-
gates_pos = jp.where(mask, gates_pos, nominal_pos)
329-
gates_rpy = jp.where(mask, gates_rpy, nominal_rpy)
320+
gates_pos = jp.where(gates_visited[..., None], mocap_pos[mocap_ids], nominal_pos)
321+
gates_rpy = jp.where(gates_visited[..., None], real_rpy, nominal_rpy)
330322
return gates_visited, gates_pos, gates_rpy
331323

332324
@staticmethod
333325
@jax.jit
334-
def obs_acc_obstacles(
335-
obstacles_visited, drone_pos, mocap_pos, mocap_ids, sensor_range, nominal_pos
336-
):
337-
# TODO: Clean up the accelerated functions
338-
obstacles_pos = mocap_pos[mocap_ids]
339-
dpos = drone_pos[..., None, :2] - obstacles_pos[:, :2]
326+
def _obs_obstacles(
327+
visited: NDArray,
328+
drone_pos: Array,
329+
mocap_pos: Array,
330+
mocap_ids: NDArray,
331+
sensor_range: float,
332+
nominal_pos: NDArray,
333+
) -> tuple[Array, Array]:
334+
dpos = drone_pos[..., None, :2] - mocap_pos[mocap_ids, :2]
340335
in_range = jp.linalg.norm(dpos, axis=-1) < sensor_range
341-
obstacles_visited = jp.logical_or(obstacles_visited, in_range)
342-
mask = obstacles_visited[..., None]
343-
obstacles_pos = jp.where(mask, obstacles_pos, nominal_pos)
344-
return obstacles_visited, obstacles_pos
336+
visited = jp.logical_or(visited, in_range)
337+
return visited, jp.where(visited[..., None], mocap_pos[mocap_ids], nominal_pos)
345338

346339
def reward(self) -> float:
347340
"""Compute the reward for the current state.
@@ -368,33 +361,20 @@ def info(self) -> dict:
368361
"""Return an info dictionary containing additional information about the environment."""
369362
return {"collisions": np.any(self.sim.contacts(), axis=-1), "symbolic_model": self.symbolic}
370363

371-
def update_active_drones(self):
372-
# TODO: Accelerate
373-
pos = self.sim.data.states.pos[0, ...]
374-
rpy = R.from_quat(self.sim.data.states.quat[0, ...]).as_euler("xyz")
375-
disabled = np.logical_or(self.disabled_drones, np.all(pos < self.pos_bounds.low, axis=-1))
376-
disabled = np.logical_or(disabled, np.all(pos > self.pos_bounds.high, axis=-1))
377-
disabled = np.logical_or(disabled, np.all(rpy < self.rpy_bounds.low, axis=-1))
378-
disabled = np.logical_or(disabled, np.all(rpy > self.rpy_bounds.high, axis=-1))
379-
disabled = np.logical_or(disabled, self.target_gate == -1)
380-
contacts = np.any(np.logical_and(self.sim.contacts(), self.contact_masks), axis=-1)
381-
disabled = np.logical_or(disabled, contacts)
382-
self.disabled_drones = disabled
383-
384364
@staticmethod
385365
@jax.jit
386-
def update_active_drones_acc(
387-
pos,
388-
quat,
389-
pos_low,
390-
pos_high,
391-
rpy_low,
392-
rpy_high,
393-
target_gate,
394-
disabled_drones,
395-
contacts,
396-
contact_masks,
397-
):
366+
def _disabled_drones(
367+
pos: Array,
368+
quat: Array,
369+
pos_low: NDArray,
370+
pos_high: NDArray,
371+
rpy_low: NDArray,
372+
rpy_high: NDArray,
373+
target_gate: NDArray,
374+
disabled_drones: NDArray,
375+
contacts: Array,
376+
contact_masks: NDArray,
377+
) -> Array:
398378
rpy = jax.scipy.spatial.transform.Rotation.from_quat(quat).as_euler("xyz")
399379
disabled = jp.logical_or(disabled_drones, jp.all(pos < pos_low, axis=-1))
400380
disabled = jp.logical_or(disabled, jp.all(pos > pos_high, axis=-1))
@@ -481,30 +461,9 @@ def _load_track_into_sim(self, gates: dict, obstacles: dict):
481461
mocap_ids = [int(mj_model.body(f"obstacle:{i}").mocapid) for i in range(n_obstacles)]
482462
obstacles["mocap_ids"] = np.array(mocap_ids, dtype=np.int32)
483463

484-
def gate_passed(self) -> bool:
485-
"""Check if the drone has passed a gate.
486-
487-
Returns:
488-
True if the drone has passed a gate, else False.
489-
"""
490-
passed = np.zeros(self.sim.n_drones, dtype=bool)
491-
if self.n_gates <= 0:
492-
return passed
493-
gate_ids = self.target_gate % self.n_gates
494-
gate_mj_id = self.gates["mocap_ids"][gate_ids]
495-
gate_pos = self.sim.data.mjx_data.mocap_pos[0, gate_mj_id].squeeze()
496-
gate_rot = R.from_quat(self.sim.data.mjx_data.mocap_quat[0, gate_mj_id], scalar_first=True)
497-
drone_pos = self.sim.data.states.pos[0]
498-
gate_size = (0.45, 0.45)
499-
for i in range(self.sim.n_drones):
500-
passed[i] = check_gate_pass(
501-
gate_pos[i], gate_rot[i], gate_size, drone_pos[i], self._last_drone_pos[i]
502-
)
503-
return passed
504-
505464
@staticmethod
506465
@jax.jit
507-
def gate_passed_accelerated(
466+
def _gate_passed(
508467
target_gate: NDArray,
509468
mocap_ids: NDArray,
510469
mocap_pos: Array,
@@ -518,18 +477,16 @@ def gate_passed_accelerated(
518477
Returns:
519478
True if the drone has passed a gate, else False.
520479
"""
521-
# TODO: Test, refactor, optimize. Cover cases with no gates.
522-
gate_ids = target_gate % n_gates
523-
gate_mj_id = mocap_ids[gate_ids]
524-
gate_pos = mocap_pos[gate_mj_id]
525-
gate_rot = jax.scipy.spatial.transform.Rotation.from_quat(
526-
mocap_quat[gate_mj_id][..., [1, 2, 3, 0]]
527-
)
480+
# TODO: Test. Cover cases with no gates.
481+
ids = mocap_ids[target_gate % n_gates]
482+
gate_pos = mocap_pos[ids]
483+
gate_quat = mocap_quat[ids][..., [1, 2, 3, 0]]
484+
gate_rot = jax.scipy.spatial.transform.Rotation.from_quat(gate_quat)
528485
gate_size = (0.45, 0.45)
529486
last_pos_local = gate_rot.apply(last_drone_pos - gate_pos, inverse=True)
530487
pos_local = gate_rot.apply(drone_pos - gate_pos, inverse=True)
531-
# Check the plane intersection. If passed, calculate the point of the intersection and check if
532-
# it is within the gate box.
488+
# Check if the line between the last position and the current position intersects the plane.
489+
# If so, calculate the point of the intersection and check if it is within the gate box.
533490
passed_plane = (last_pos_local[..., 1] < 0) & (pos_local[..., 1] > 0)
534491
alpha = -last_pos_local[..., 1] / (pos_local[..., 1] - last_pos_local[..., 1])
535492
x_intersect = alpha * (pos_local[..., 0]) + (1 - alpha) * last_pos_local[..., 0]
@@ -540,6 +497,7 @@ def gate_passed_accelerated(
540497
@staticmethod
541498
@jax.jit
542499
def warp_disabled_drones(data: SimData, mask: NDArray) -> SimData:
500+
"""Warp the disabled drones below the ground."""
543501
mask = mask.reshape((1, -1, 1))
544502
pos = jax.numpy.where(mask, -1, data.states.pos)
545503
return data.replace(states=data.states.replace(pos=pos))

0 commit comments

Comments
 (0)