Skip to content

Commit ab59746

Browse files
committed
Update Vicon to track gates.
Use exact gate position during deployment when sufficiently close. [TODO: Test on real setup]
1 parent bfd654b commit ab59746

File tree

3 files changed

+85
-36
lines changed

3 files changed

+85
-36
lines changed

lsy_drone_racing/utils.py

+12
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,18 @@ def euler_from_quaternion(x: float, y: float, z: float, w: float) -> tuple[float
3737
return roll_x, pitch_y, yaw_z # in radians
3838

3939

40+
def map2pi(angle: np.ndarray) -> np.ndarray:
41+
"""Map an angle or array of angles to the interval of [-pi, pi].
42+
43+
Args:
44+
angle: Number or array of numbers.
45+
46+
Returns:
47+
The remapped angles.
48+
"""
49+
return ((angle + np.pi) % (2 * np.pi)) - np.pi
50+
51+
4052
def load_controller(path: Path) -> Type[BaseController]:
4153
"""Load the controller module from the given path and return the Controller class.
4254

lsy_drone_racing/vicon.py

+54-18
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,66 @@
11
from __future__ import annotations
22

3+
import time
4+
35
import numpy as np
46
import rospy
57
import yaml
68
from rosgraph import Master
79
from tf2_msgs.msg import TFMessage
810

911
from lsy_drone_racing.import_utils import get_ros_package_path
10-
from lsy_drone_racing.utils import euler_from_quaternion
12+
from lsy_drone_racing.utils import euler_from_quaternion, map2pi
1113

1214

13-
class ViconWatcher:
15+
class Vicon:
1416
"""Vicon interface for the pose estimation data for the drone and any other tracked objects.
1517
1618
Vicon sends a stream of ROS messages containing the current pose data. We subscribe to these
1719
messages and save the pose data for each object in dictionaries. Users can then retrieve the
1820
latest pose data directly from these dictionaries.
1921
"""
2022

21-
def __init__(self, track_names: list[str] = []):
23+
def __init__(
24+
self, track_names: list[str] = [], auto_track_drone: bool = True, timeout: float = 0.0
25+
):
2226
"""Load the crazyflies.yaml file and register the subscribers for the Vicon pose data.
2327
2428
Args:
2529
track_names: The names of any additional objects besides the drone to track.
30+
auto_track_drone: Infer the drone name and add it to the positions if True.
31+
timeout: If greater than 0, Vicon waits for position updates of all tracked objects
32+
before returning.
2633
"""
2734
assert Master("/rosnode").is_online(), "ROS is not running. Please run hover.launch first!"
2835
try:
2936
rospy.init_node("playback_node")
3037
except rospy.exceptions.ROSException:
3138
... # ROS node is already running which is fine for us
32-
config_path = get_ros_package_path("crazyswarm") / "launch/crazyflies.yaml"
33-
assert config_path.exists(), "Crazyfly config file missing!"
34-
with open(config_path, "r") as f:
35-
config = yaml.load(f, yaml.SafeLoader)
36-
assert len(config["crazyflies"]) == 1, "Only one crazyfly allowed at a time!"
37-
self.drone_name = f"cf{config['crazyflies'][0]['id']}"
38-
39+
self.drone_name = None
40+
if auto_track_drone:
41+
with open(get_ros_package_path("crazyswarm") / "launch/crazyflies.yaml", "r") as f:
42+
config = yaml.load(f, yaml.SafeLoader)
43+
assert len(config["crazyflies"]) == 1, "Only one crazyfly allowed at a time!"
44+
self.drone_name = f"cf{config['crazyflies'][0]['id']}"
45+
track_names.insert(0, self.drone_name)
46+
self.track_names = track_names
3947
# Register the Vicon subscribers for the drone and any other tracked object
40-
self.pos: dict[str, np.ndarray] = {"cf": np.array([])}
41-
self.rpy: dict[str, np.ndarray] = {"cf": np.array([])}
42-
for track_name in track_names: # Initialize the objects' pose
43-
self.pos[track_name], self.rpy[track_name] = np.array([]), np.array([])
48+
self.pos: dict[str, np.ndarray] = {}
49+
self.rpy: dict[str, np.ndarray] = {}
50+
self.vel: dict[str, np.ndarray] = {}
51+
self.ang_vel: dict[str, np.ndarray] = {}
52+
self.time: dict[str, float] = {}
4453

4554
self.sub = rospy.Subscriber("/tf", TFMessage, self.save_pose)
55+
if timeout:
56+
tstart = time.time()
57+
while not self.active and time.time() - tstart < timeout:
58+
time.sleep(0.01)
59+
if not self.active:
60+
raise TimeoutError(
61+
"Timeout while fetching initial position updates for all tracked objects."
62+
f"Missing objects: {[k for k in self.track_names if k not in self.ang_vel]}"
63+
)
4664

4765
def save_pose(self, data: TFMessage):
4866
"""Save the position and orientation of all transforms.
@@ -51,10 +69,18 @@ def save_pose(self, data: TFMessage):
5169
data: The TF message containing the objects' pose.
5270
"""
5371
for tf in data.transforms:
54-
name = "cf" if tf.child_frame_id == self.drone_name else tf.child_frame_id
72+
name = tf.child_frame_id.split("/")[-1]
73+
if name not in self.pos:
74+
continue
5575
T, R = tf.transform.translation, tf.transform.rotation
56-
self.pos[name] = np.array([T.x, T.y, T.z])
57-
self.rpy[name] = np.array(euler_from_quaternion(R.x, R.y, R.z, R.w))
76+
pos = np.array([T.x, T.y, T.z])
77+
rpy = np.array(euler_from_quaternion(R.x, R.y, R.z, R.w))
78+
if self.pos[name]:
79+
self.vel[name] = (pos - self.pos[name]) / (time.time() - self.time[name])
80+
self.ang_vel[name] = map2pi(rpy - self.rpy[name]) / (time.time() - self.time[name])
81+
self.time[name] = time.time()
82+
self.pos[name] = pos
83+
self.rpy[name] = rpy
5884

5985
def pose(self, name: str) -> tuple[np.ndarray, np.ndarray]:
6086
"""Get the latest pose of a tracked object.
@@ -67,7 +93,17 @@ def pose(self, name: str) -> tuple[np.ndarray, np.ndarray]:
6793
"""
6894
return self.pos[name], self.rpy[name]
6995

96+
@property
97+
def poses(self) -> tuple[np.ndarray, np.ndarray]:
98+
"""Get the latest poses of all objects."""
99+
return np.stack(self.pos.values()), np.stack(self.rpy.values())
100+
101+
@property
102+
def names(self) -> list[str]:
103+
"""Get a list of actively tracked names."""
104+
return list(self.pos.keys())
105+
70106
@property
71107
def active(self) -> bool:
72108
"""Check if Vicon has sent data for each object."""
73-
return all(p.size > 0 for p in self.pos.values())
109+
return all([name in self.ang_vel for name in self.track_names])

scripts/deploy.py

+19-18
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from lsy_drone_racing.command import Command, apply_command
2525
from lsy_drone_racing.import_utils import get_ros_package_path, pycrazyswarm
2626
from lsy_drone_racing.utils import check_gate_pass, load_controller
27-
from lsy_drone_racing.vicon import ViconWatcher
27+
from lsy_drone_racing.vicon import Vicon
2828

2929
logger = logging.getLogger(__name__)
3030

@@ -93,15 +93,9 @@ def main(config: str = "config/getting_started.yaml", controller: str = "example
9393
time_helper = swarm.timeHelper
9494
cf = swarm.allcfs.crazyflies[0]
9595

96-
vicon = ViconWatcher() # TODO: Integrate autodetection of gate and obstacle positions
97-
98-
timeout = 5.0
99-
tstart = time.time()
100-
while not vicon.active:
101-
logger.info("Waiting for vicon...")
102-
time.sleep(1)
103-
if time.time() - tstart > timeout:
104-
raise TimeoutError("Vicon unavailable.")
96+
gate_names = [f"gate{i}" for i in range(1, len(config.quadrotor_config.gates) + 1)]
97+
obstacle_names = [f"obstacle{i}" for i in range(1, len(config.quadrotor_config.obstacles) + 1)]
98+
vicon = Vicon(track_names=gate_names + obstacle_names, timeout=1.0)
10599

106100
config_path = Path(config).resolve()
107101
assert config_path.is_file(), "Config file does not exist!"
@@ -137,8 +131,9 @@ def main(config: str = "config/getting_started.yaml", controller: str = "example
137131
_, env_info = env.reset()
138132

139133
# Override environment state and evaluate constraints
140-
drone_pos_and_vel = [vicon.pos["cf"][0], 0, vicon.pos["cf"][1], 0, vicon.pos["cf"][2], 0]
141-
drone_rot_and_agl_vel = [vicon.rpy["cf"][0], vicon.rpy["cf"][1], vicon.rpy["cf"][2], 0, 0, 0]
134+
drone_pos, drone_rot = vicon.pos[vicon.drone_name], vicon.rpy[vicon.drone_name]
135+
drone_pos_and_vel = [drone_pos[0], 0, drone_pos[1], 0, drone_pos[2], 0]
136+
drone_rot_and_agl_vel = [drone_rot[0], drone_rot[1], drone_rot[2], 0, 0, 0]
142137
env.state = drone_pos_and_vel + drone_rot_and_agl_vel
143138
constraint_values = env.constraints.get_values(env, only_state=True)
144139
x_reference = config.quadrotor_config.task_info.stabilization_goal
@@ -179,13 +174,17 @@ def main(config: str = "config/getting_started.yaml", controller: str = "example
179174
p = vicon.pos["cf"]
180175
# This only looks at the x-y plane, could be improved
181176
# TODO: Replace with 3D distance once gate poses are given with height
182-
gate_dist = np.sqrt(np.sum((p[0:2] - gate_poses[target_gate_id][0:2]) ** 2))
177+
gate_dist = np.sqrt(np.sum((p[0:2] - vicon.pos[gate_names[target_gate_id]][0:2]) ** 2))
178+
if gate_dist < 0.45:
179+
current_target_gate_pos = vicon.pos[gate_names[target_gate_id]]
180+
else:
181+
current_target_gate_pos = gate_poses[target_gate_id][0:6]
183182
info = {
184183
"mse": np.sum(state_error**2),
185184
"collision": (None, False), # Leave always false in sim2real
186185
"current_target_gate_id": target_gate_id,
187186
"current_target_gate_in_range": gate_dist < 0.45,
188-
"current_target_gate_pos": gate_poses[target_gate_id][0:6], # Always "exact"
187+
"current_target_gate_pos": current_target_gate_pos,
189188
"current_target_gate_type": gate_poses[target_gate_id][6],
190189
"at_goal_position": False, # Leave always false in sim2real
191190
"task_completed": False, # Leave always false in sim2real
@@ -194,19 +193,21 @@ def main(config: str = "config/getting_started.yaml", controller: str = "example
194193
}
195194

196195
# Check if the drone has passed the current gate
197-
if check_gate_pass(gate_poses[target_gate_id], vicon.pos["cf"], last_drone_pos):
196+
if check_gate_pass(
197+
gate_poses[target_gate_id], vicon.pos[vicon.drone_name], last_drone_pos
198+
):
198199
target_gate_id += 1
199200
print(f"Gate {target_gate_id} passed in {curr_time:.4}s")
200-
last_drone_pos = vicon.pos["cf"].copy()
201+
last_drone_pos = vicon.pos[vicon.drone_name].copy()
201202

202203
if target_gate_id == len(gate_poses): # Reached the end
203204
target_gate_id = -1
204205
total_time = time.time() - start_time
205206

206207
# Get the latest vicon observation and call the controller
207-
p = vicon.pos["cf"]
208+
p = vicon.pos[vicon.drone_name]
208209
drone_pos_and_vel = [p[0], 0, p[1], 0, p[2], 0]
209-
r = vicon.rpy["cf"]
210+
r = vicon.rpy[vicon.drone_name]
210211
drone_rot_and_agl_vel = [r[0], r[1], r[2], 0, 0, 0]
211212
vicon_obs = drone_pos_and_vel + drone_rot_and_agl_vel
212213
# In sim2real: Reward always 0, done always false

0 commit comments

Comments
 (0)