Skip to content

Commit a73f377

Browse files
committed
Optimize performance. Fix gate detection
1 parent 423979b commit a73f377

File tree

2 files changed

+15
-16
lines changed

2 files changed

+15
-16
lines changed

lsy_drone_racing/envs/drone_racing_env.py

+12-13
Original file line numberDiff line numberDiff line change
@@ -221,8 +221,6 @@ def obs(self) -> dict[str, NDArray[np.floating]]:
221221
"vel": np.array(self.sim.data.states.vel[0, 0], dtype=np.float32),
222222
"ang_vel": np.array(self.sim.data.states.rpy_rates[0, 0], dtype=np.float32),
223223
}
224-
obs["ang_vel"][:] = R.from_euler("xyz", obs["rpy"]).apply(obs["ang_vel"], inverse=True)
225-
226224
obs["target_gate"] = self.target_gate if self.target_gate < len(self.gates) else -1
227225
# Add the gate and obstacle poses to the info. If gates or obstacles are in sensor range,
228226
# use the actual pose, otherwise use the nominal pose.
@@ -247,9 +245,7 @@ def obs(self) -> dict[str, NDArray[np.floating]]:
247245
obstacles_pos[self.obstacles_visited] = self.obstacles["pos"][self.obstacles_visited]
248246
obs["obstacles_pos"] = obstacles_pos.astype(np.float32)
249247
obs["obstacles_visited"] = self.obstacles_visited
250-
251-
if "observation" in self.disturbances:
252-
obs = self.disturbances["observation"].apply(obs)
248+
# TODO: Observation disturbances?
253249
return obs
254250

255251
def reward(self) -> float:
@@ -370,10 +366,13 @@ def _load_track_into_sim(self, gates: dict, obstacles: dict):
370366
assert not hasattr(self.sim.data, "gate_pos")
371367
assert not hasattr(self.sim.data, "obstacle_pos")
372368

373-
gate_ids = [self.sim.mj_model.body(f"gate:{i}").id for i in range(n_gates)]
374-
gates["ids"] = gate_ids
375-
obstacle_ids = [self.sim.mj_model.body(f"obstacle:{i}").id for i in range(n_obstacles)]
376-
obstacles["ids"] = obstacle_ids
369+
mj_model = self.sim.mj_model
370+
gates["ids"] = [mj_model.body(f"gate:{i}").id for i in range(n_gates)]
371+
gates["mocap_ids"] = [int(mj_model.body(f"gate:{i}").mocapid) for i in range(n_gates)]
372+
obstacles["ids"] = [mj_model.body(f"obstacle:{i}").id for i in range(n_obstacles)]
373+
obstacles["mocap_ids"] = [
374+
int(mj_model.body(f"obstacle:{i}").mocapid) for i in range(n_obstacles)
375+
]
377376

378377
def gate_passed(self) -> bool:
379378
"""Check if the drone has passed a gate.
@@ -383,12 +382,12 @@ def gate_passed(self) -> bool:
383382
"""
384383
if self.n_gates <= 0 or self.target_gate >= self.n_gates or self.target_gate == -1:
385384
return False
386-
gate_id = self.gates["ids"][self.target_gate]
387-
gate_pos = self.sim.data.mjx_data.mocap_pos[0, gate_id]
388-
gate_quat = self.sim.data.mjx_data.mocap_quat[0, gate_id][..., [3, 0, 1, 2]]
385+
gate_mj_id = self.gates["mocap_ids"][self.target_gate]
386+
gate_pos = self.sim.data.mjx_data.mocap_pos[0, gate_mj_id].squeeze()
387+
gate_rot = R.from_quat(self.sim.data.mjx_data.mocap_quat[0, gate_mj_id], scalar_first=True)
389388
drone_pos = self.sim.data.states.pos[0, 0]
390389
gate_size = (0.45, 0.45)
391-
return check_gate_pass(gate_pos, gate_quat, gate_size, drone_pos, self._last_drone_pos)
390+
return check_gate_pass(gate_pos, gate_rot, gate_size, drone_pos, self._last_drone_pos)
392391

393392
def close(self):
394393
"""Close the environment by stopping the drone and landing back at the starting position."""

lsy_drone_racing/utils/utils.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def load_config(path: Path) -> Munch:
9191

9292
def check_gate_pass(
9393
gate_pos: np.ndarray,
94-
gate_quat: np.ndarray,
94+
gate_rot: R,
9595
gate_size: np.ndarray,
9696
drone_pos: np.ndarray,
9797
last_drone_pos: np.ndarray,
@@ -111,13 +111,13 @@ def check_gate_pass(
111111
112112
Args:
113113
gate_pos: The position of the gate in the world frame.
114-
gate_quat: The quaternion of the gate in the world frame.
114+
gate_rot: The rotation of the gate.
115115
gate_size: The size of the gate box in meters.
116116
drone_pos: The position of the drone in the world frame.
117117
last_drone_pos: The position of the drone in the world frame at the last time step.
118118
"""
119119
# Transform last and current drone position into current gate frame.
120-
gate_rot = R(gate_quat, normalize=False, copy=False)
120+
assert isinstance(gate_rot, R), "gate_rot has to be a Rotation object."
121121
last_pos_local = gate_rot.apply(last_drone_pos - gate_pos, inverse=True)
122122
pos_local = gate_rot.apply(drone_pos - gate_pos, inverse=True)
123123
# Check the plane intersection. If passed, calculate the point of the intersection and check if

0 commit comments

Comments
 (0)