Skip to content

Commit 26141b0

Browse files
committed
Small renamings
1 parent 4e81ac7 commit 26141b0

File tree

1 file changed

+10
-13
lines changed

1 file changed

+10
-13
lines changed

lsy_drone_racing/envs/multi_drone_race.py

+10-13
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from crazyflow import Sim
2020
from crazyflow.sim.symbolic import symbolic_attitude
2121
from gymnasium import spaces
22+
from jax.scipy.spatial.transform import Rotation as JaxR
2223
from scipy.spatial.transform import Rotation as R
2324

2425
from lsy_drone_racing.envs.randomize import (
@@ -302,7 +303,7 @@ def obs(self) -> dict[str, NDArray[np.floating]]:
302303
@staticmethod
303304
@jax.jit
304305
def _obs_gates(
305-
gates_visited: NDArray,
306+
visited: NDArray,
306307
drone_pos: Array,
307308
mocap_pos: Array,
308309
mocap_quat: Array,
@@ -312,14 +313,12 @@ def _obs_gates(
312313
nominal_rpy: NDArray,
313314
) -> tuple[Array, Array, Array]:
314315
"""Get the nominal or real gate positions and orientations depending on the sensor range."""
315-
real_quat = mocap_quat[mocap_ids][..., [1, 2, 3, 0]]
316-
real_rpy = jax.scipy.spatial.transform.Rotation.from_quat(real_quat).as_euler("xyz")
316+
real_rpy = JaxR.from_quat(mocap_quat[mocap_ids][..., [1, 2, 3, 0]]).as_euler("xyz")
317317
dpos = drone_pos[..., None, :2] - mocap_pos[mocap_ids, :2]
318-
in_range = jp.linalg.norm(dpos, axis=-1) < sensor_range
319-
gates_visited = jp.logical_or(gates_visited, in_range)
320-
gates_pos = jp.where(gates_visited[..., None], mocap_pos[mocap_ids], nominal_pos)
321-
gates_rpy = jp.where(gates_visited[..., None], real_rpy, nominal_rpy)
322-
return gates_visited, gates_pos, gates_rpy
318+
visited = jp.logical_or(visited, jp.linalg.norm(dpos, axis=-1) < sensor_range)
319+
gates_pos = jp.where(visited[..., None], mocap_pos[mocap_ids], nominal_pos)
320+
gates_rpy = jp.where(visited[..., None], real_rpy, nominal_rpy)
321+
return visited, gates_pos, gates_rpy
323322

324323
@staticmethod
325324
@jax.jit
@@ -332,8 +331,7 @@ def _obs_obstacles(
332331
nominal_pos: NDArray,
333332
) -> tuple[Array, Array]:
334333
dpos = drone_pos[..., None, :2] - mocap_pos[mocap_ids, :2]
335-
in_range = jp.linalg.norm(dpos, axis=-1) < sensor_range
336-
visited = jp.logical_or(visited, in_range)
334+
visited = jp.logical_or(visited, jp.linalg.norm(dpos, axis=-1) < sensor_range)
337335
return visited, jp.where(visited[..., None], mocap_pos[mocap_ids], nominal_pos)
338336

339337
def reward(self) -> float:
@@ -375,7 +373,7 @@ def _disabled_drones(
375373
contacts: Array,
376374
contact_masks: NDArray,
377375
) -> Array:
378-
rpy = jax.scipy.spatial.transform.Rotation.from_quat(quat).as_euler("xyz")
376+
rpy = JaxR.from_quat(quat).as_euler("xyz")
379377
disabled = jp.logical_or(disabled_drones, jp.all(pos < pos_low, axis=-1))
380378
disabled = jp.logical_or(disabled, jp.all(pos > pos_high, axis=-1))
381379
disabled = jp.logical_or(disabled, jp.all(rpy < rpy_low, axis=-1))
@@ -480,8 +478,7 @@ def _gate_passed(
480478
# TODO: Test. Cover cases with no gates.
481479
ids = mocap_ids[target_gate % n_gates]
482480
gate_pos = mocap_pos[ids]
483-
gate_quat = mocap_quat[ids][..., [1, 2, 3, 0]]
484-
gate_rot = jax.scipy.spatial.transform.Rotation.from_quat(gate_quat)
481+
gate_rot = JaxR.from_quat(mocap_quat[ids][..., [1, 2, 3, 0]])
485482
gate_size = (0.45, 0.45)
486483
last_pos_local = gate_rot.apply(last_drone_pos - gate_pos, inverse=True)
487484
pos_local = gate_rot.apply(drone_pos - gate_pos, inverse=True)

0 commit comments

Comments
 (0)