Skip to content

Commit 8ccf9c9

Browse files
committed
Fix deploy script
1 parent ac53f07 commit 8ccf9c9

File tree

3 files changed

+123
-118
lines changed

3 files changed

+123
-118
lines changed

lsy_drone_racing/command.py

+2
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ def apply_command(cf: Crazyflie, command_type: Command, args: Any):
6969
args: Additional arguments as potentially required by `command_type`.
7070
"""
7171
if command_type == Command.FULLSTATE:
72+
# Sim version takes an additional 'ep_time' args that needs to be removed when deployed
73+
args = args[:5]
7274
cf.cmdFullState(*args)
7375
elif command_type == Command.TAKEOFF:
7476
cf.takeoff(*args)

lsy_drone_racing/vicon.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@ def __init__(self, track_names: list[str] = []):
2525
track_names: The names of any additional objects besides the drone to track.
2626
"""
2727
assert Master("/rosnode").is_online(), "ROS is not running. Please run hover.launch first!"
28-
rospy.init_node("playback_node")
28+
try:
29+
rospy.init_node("playback_node")
30+
except rospy.exceptions.ROSException:
31+
... # ROS node is already running which is fine for us
2932
config_path = get_ros_package_path("crazyswarm") / "launch/crazyflies.yaml"
3033
assert config_path.exists(), "Crazyfly config file missing!"
3134
with open(config_path, "r") as f:

scripts/deploy.py

+117-117
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,25 @@
99

1010
from __future__ import annotations
1111

12+
import logging
1213
import pickle
1314
import time
1415
from pathlib import Path
15-
from typing import Optional
1616

1717
import fire
1818
import numpy as np
19-
import rospkg
2019
import rospy
2120
import yaml
22-
from pycrazyswarm import Crazyswarm
2321
from safe_control_gym.utils.configuration import ConfigFactory
2422
from safe_control_gym.utils.registration import make
2523

26-
from lsy_drone_racing.command import apply_command
24+
from lsy_drone_racing.command import Command, apply_command
25+
from lsy_drone_racing.import_utils import get_ros_package_path, pycrazyswarm
2726
from lsy_drone_racing.utils import check_gate_pass, load_controller
2827
from lsy_drone_racing.vicon import ViconWatcher
2928

29+
logger = logging.getLogger(__name__)
30+
3031

3132
def create_init_info(
3233
env_info: dict, gate_poses: list, obstacle_poses: list, constraint_values: list
@@ -55,7 +56,7 @@ def create_init_info(
5556
"x_reference": [-0.5, 0.0, 2.9, 0.0, 0.75, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
5657
"u_reference": [0.084623, 0.084623, 0.084623, 0.084623],
5758
"symbolic_constraints": env_info["symbolic_constraints"],
58-
"ctrl_timestep": 0.03333333333333333,
59+
"ctrl_timestep": 1 / 30,
5960
"ctrl_freq": 30,
6061
"episode_len_sec": 33,
6162
"quadrotor_kf": 3.16e-10,
@@ -75,62 +76,64 @@ def create_init_info(
7576
return init_info
7677

7778

78-
def main(controller_path: str, config: str = "", overrides: Optional[str] = None):
79+
def main(config: str = "config/getting_started.yaml", controller: str = "examples/controller.py"):
7980
"""Deployment script to run the controller on the real drone."""
8081
start_time = time.time()
8182

8283
# Load the controller and initialize the crazyfly interface
83-
Controller = load_controller(Path(controller_path))
84-
pkg = rospkg.RosPack()
85-
assert "crazyswarm" in pkg.list(), "Crazyswarm package not found. Did you source the workspace?"
86-
crazyswarm_path = Path(pkg.get_path("crazyswarm"))
87-
swarm = Crazyswarm(crazyswarm_path / "launch/crazyflies.yaml", args="--sim")
84+
Controller = load_controller(Path(controller))
85+
crazyswarm_config_path = get_ros_package_path("crazyswarm") / "launch/crazyflies.yaml"
86+
# pycrazyswarm expects strings, not Path objects, so we need to convert it first
87+
swarm = pycrazyswarm.Crazyswarm(str(crazyswarm_config_path))
8888
time_helper = swarm.timeHelper
8989
cf = swarm.allcfs.crazyflies[0]
9090

91-
vicon = ViconWatcher()
92-
# TODO: Replace with autodetection of gate and obstacle positions
93-
"""hi1 = ObjectWatcher("cf_hi1")
94-
hi2 = ObjectWatcher("cf_hi2")
95-
lo1 = ObjectWatcher("cf_lo1")
96-
lo2 = ObjectWatcher("cf_lo2")
97-
obs1 = ObjectWatcher("cf_obs1")
98-
obs2 = ObjectWatcher("cf_obs2")
99-
obs3 = ObjectWatcher("cf_obs3")
100-
obs4 = ObjectWatcher("cf_obs4")"""
101-
102-
timeout = 10.0
91+
vicon = ViconWatcher() # TODO: Integrate autodetection of gate and obstacle positions
92+
93+
timeout = 5.0
10394
tstart = time.time()
10495
while not vicon.active:
105-
print("Waiting for vicon...")
96+
logger.info("Waiting for vicon...")
10697
time.sleep(1)
10798
if time.time() - tstart > timeout:
10899
raise TimeoutError("Vicon unavailable.")
109100

110-
# TODO: Replace config with autodetection of crazyswarm ROS package
111101
config_path = Path(config).resolve()
112102
assert config_path.is_file(), "Config file does not exist!"
113103
with open(config_path, "r") as f:
114-
config = yaml.load(f)
104+
config = yaml.load(f, yaml.SafeLoader)
105+
config_factory = ConfigFactory()
106+
config_factory.base_dict = config
107+
config = config_factory.merge()
108+
109+
# Check if the real drone position matches the settings
110+
tol = 0.1
111+
init_state = config.quadrotor_config.init_state
112+
drone_pos = np.array([init_state[key] for key in ("init_x", "init_y", "init_z")])
113+
if d := np.linalg.norm(drone_pos - vicon.pos["cf"]) > tol:
114+
raise RuntimeError(
115+
(
116+
f"Distance between drone and starting position too great ({d:.2f}m)"
117+
f"Position is {vicon.pos['cf']}, should be {drone_pos}"
118+
)
119+
)
115120

116121
# TODO: Replace with autodetection of gate and obstacle positions
117122
# TODO: Change obstacle and gate definitions to freely adjust the height
118-
gate_poses = config["gates_pos_and_type"]
123+
gate_poses = config.quadrotor_config.gates
119124
for gate in gate_poses:
120125
if gate[3] != 0 or gate[4] != 0:
121126
raise ValueError("Gates can't have roll or pitch!")
122-
obstacle_poses = config["obstacles_pos"]
127+
obstacle_poses = config.quadrotor_config.obstacles
123128

124129
# Create a safe-control-gym environment from which to take the symbolic models
125-
CONFIG_FACTORY = ConfigFactory()
126-
config = CONFIG_FACTORY.merge()
127130
config.quadrotor_config["ctrl_freq"] = 500
128131
env = make("quadrotor", **config.quadrotor_config)
129132
_, env_info = env.reset()
130133

131134
# Override environment state and evaluate constraints
132-
drone_pos_and_vel = [vicon.pos[0], 0, vicon.pos[1], 0, vicon.pos[2], 0]
133-
drone_rot_and_agl_vel = [vicon.rpy[0], vicon.rpy[1], vicon.rpy[2], 0, 0, 0]
135+
drone_pos_and_vel = [vicon.pos["cf"][0], 0, vicon.pos["cf"][1], 0, vicon.pos["cf"][2], 0]
136+
drone_rot_and_agl_vel = [vicon.rpy["cf"][0], vicon.rpy["cf"][1], vicon.rpy["cf"][2], 0, 0, 0]
134137
env.state = drone_pos_and_vel + drone_rot_and_agl_vel
135138
constraint_values = env.constraints.get_values(env, only_state=True)
136139

@@ -145,97 +148,94 @@ def main(controller_path: str, config: str = "", overrides: Optional[str] = None
145148
# Helper parameters
146149
target_gate_id = 0 # Initial gate.
147150
log_cmd = [] # Log commands as [current time, ros time, command type, args]
148-
last_drone_pos = vicon.pos.copy() # Helper for determining if the drone has crossed a goal
151+
last_drone_pos = vicon.pos["cf"].copy() # Gate crossing helper
149152
completed = False
150153
print(f"Setup time: {time.time() - start_time:.3}s")
151154

152-
# Run the main control loop
153-
start_time = time.time()
154-
while not time_helper.isShutdown():
155-
curr_time = time.time() - start_time
156-
157-
# Override environment state and evaluate constraints
158-
env.state = [
159-
vicon.pos[0],
160-
0,
161-
vicon.pos[1],
162-
0,
163-
vicon.pos[2],
164-
0,
165-
vicon.rpy[0],
166-
vicon.rpy[1],
167-
vicon.rpy[2],
168-
0,
169-
0,
170-
0,
171-
]
172-
state_error = (env.state - env.X_GOAL) * env.info_mse_metric_state_weight
173-
constraint_values = env.constraints.get_values(env, only_state=True)
174-
# IROS 2022 - Constrain violation flag for reward.
175-
env.cnstr_violation = env.constraints.is_violated(env, c_value=constraint_values)
176-
cnstr_num = 1 if env.cnstr_violation else 0
177-
178-
# This only looks at the x-y plane, could be improved
179-
# TODO: Replace with 3D distance once gate poses are given with height
180-
gate_dist = np.sqrt(np.sum((vicon.pos[0:2] - gate_poses[target_gate_id][0:2]) ** 2))
181-
info = {
182-
"mse": np.sum(state_error**2),
183-
"collision": (None, False), # Leave always false in sim2real
184-
"current_target_gate_id": target_gate_id,
185-
"current_target_gate_in_range": gate_dist < 0.45,
186-
"current_target_gate_pos": gate_poses[target_gate_id][0:6], # Always "exact"
187-
"current_target_gate_type": gate_poses[target_gate_id][6],
188-
"at_goal_position": False, # Leave always false in sim2real
189-
"task_completed": False, # Leave always false in sim2real
190-
"constraint_values": constraint_values,
191-
"constraint_violation": cnstr_num,
192-
}
193-
194-
# Check if the drone has passed the current gate
195-
if check_gate_pass(gate_poses[target_gate_id], vicon.pos, last_drone_pos):
196-
target_gate_id += 1
197-
print(f"Gate {target_gate_id} passed in {curr_time:.4}s")
198-
last_drone_pos = vicon.pos.copy()
199-
200-
if target_gate_id == len(gate_poses): # Reached the end
201-
target_gate_id = -1
202-
at_goal_time = time.time()
203-
204-
if target_gate_id == -1:
205-
goal_pos = np.array([env.X_GOAL[0], env.X_GOAL[2], env.X_GOAL[4]])
206-
goal_dist = np.linalg.norm(vicon.pos[0:3] - goal_pos)
207-
print(f"{time.time() - at_goal_time:.4}s and {goal_dist}m away")
208-
if goal_dist >= 0.15:
209-
print(f"First hit goal position in {curr_time:.4}s")
155+
try:
156+
# Run the main control loop
157+
start_time = time.time()
158+
while not time_helper.isShutdown():
159+
curr_time = time.time() - start_time
160+
161+
# Override environment state and evaluate constraints
162+
p, r = vicon.pos["cf"], vicon.rpy["cf"]
163+
env.state = [p[0], 0, p[1], 0, p[2], 0, r[0], r[1], r[2], 0, 0, 0]
164+
state_error = (env.state - env.X_GOAL) * env.info_mse_metric_state_weight
165+
constraint_values = env.constraints.get_values(env, only_state=True)
166+
# IROS 2022 - Constrain violation flag for reward.
167+
env.cnstr_violation = env.constraints.is_violated(env, c_value=constraint_values)
168+
cnstr_num = 1 if env.cnstr_violation else 0
169+
170+
p = vicon.pos["cf"]
171+
# This only looks at the x-y plane, could be improved
172+
# TODO: Replace with 3D distance once gate poses are given with height
173+
gate_dist = np.sqrt(np.sum((p[0:2] - gate_poses[target_gate_id][0:2]) ** 2))
174+
info = {
175+
"mse": np.sum(state_error**2),
176+
"collision": (None, False), # Leave always false in sim2real
177+
"current_target_gate_id": target_gate_id,
178+
"current_target_gate_in_range": gate_dist < 0.45,
179+
"current_target_gate_pos": gate_poses[target_gate_id][0:6], # Always "exact"
180+
"current_target_gate_type": gate_poses[target_gate_id][6],
181+
"at_goal_position": False, # Leave always false in sim2real
182+
"task_completed": False, # Leave always false in sim2real
183+
"constraint_values": constraint_values,
184+
"constraint_violation": cnstr_num,
185+
}
186+
187+
# Check if the drone has passed the current gate
188+
if check_gate_pass(gate_poses[target_gate_id], vicon.pos["cf"], last_drone_pos):
189+
target_gate_id += 1
190+
print(f"Gate {target_gate_id} passed in {curr_time:.4}s")
191+
last_drone_pos = vicon.pos["cf"].copy()
192+
193+
if target_gate_id == len(gate_poses): # Reached the end
194+
target_gate_id = -1
210195
at_goal_time = time.time()
211-
elif time.time() - at_goal_time > 2:
212-
print(f"Task Completed in {curr_time:.4}s")
213-
completed = True
214-
215-
# Get the latest vicon observation and call the controller
216-
drone_pos_and_vel = [vicon.pos[0], 0, vicon.pos[1], 0, vicon.pos[2], 0]
217-
drone_rot_and_agl_vel = [vicon.rpy[0], vicon.rpy[1], vicon.rpy[2], 0, 0, 0]
218-
vicon_obs = drone_pos_and_vel + drone_rot_and_agl_vel
219-
# In sim2real: Reward always 0, done always false
220-
command_type, args = ctrl.compute_control(curr_time, vicon_obs, 0, False, info)
221-
log_cmd.append([curr_time, rospy.get_time(), command_type, args]) # Save cmd for logging
222-
223-
apply_command(cf, command_type, args) # Send the command to the drone controller
224-
time_helper.sleepForRate(CTRL_FREQ)
225-
226-
if completed:
227-
break
228-
229-
# Land does not wait for the drone to actually land, so we have to wait manually
230-
cf.land(0, 3)
231-
time_helper.sleep(3.5)
232196

233-
# Save the commands for logging
234-
save_dir = Path(__file__).resolve().parents[1] / "logs"
235-
save_dir.mkdir(parents=True, exist_ok=True)
236-
with open(save_dir / "log.pkl", "wb") as f:
237-
pickle.dump(log_cmd, f)
197+
# TODO: Clean this up
198+
if target_gate_id == -1:
199+
goal_pos = np.array([env.X_GOAL[0], env.X_GOAL[2], env.X_GOAL[4]])
200+
goal_dist = np.linalg.norm(vicon.pos["cf"][0:3] - goal_pos)
201+
print(f"{time.time() - at_goal_time:.4}s and {goal_dist}m away")
202+
if goal_dist >= 0.15:
203+
print(f"First hit goal position in {curr_time:.4}s")
204+
at_goal_time = time.time()
205+
elif time.time() - at_goal_time > 2:
206+
print(f"Task Completed in {curr_time:.4}s")
207+
completed = True
208+
209+
# Get the latest vicon observation and call the controller
210+
p = vicon.pos["cf"]
211+
drone_pos_and_vel = [p[0], 0, p[1], 0, p[2], 0]
212+
r = vicon.rpy["cf"]
213+
drone_rot_and_agl_vel = [r[0], r[1], r[2], 0, 0, 0]
214+
vicon_obs = drone_pos_and_vel + drone_rot_and_agl_vel
215+
# In sim2real: Reward always 0, done always false
216+
command_type, args = ctrl.compute_control(curr_time, vicon_obs, 0, False, info)
217+
log_cmd.append([curr_time, rospy.get_time(), command_type, args]) # Save for logging
218+
219+
apply_command(cf, command_type, args) # Send the command to the drone controller
220+
time_helper.sleepForRate(CTRL_FREQ)
221+
222+
if completed:
223+
break
224+
225+
# Land does not wait for the drone to actually land, so we have to wait manually
226+
cf.land(0, 3)
227+
time_helper.sleep(3.5)
228+
229+
# Save the commands for logging
230+
save_dir = Path(__file__).resolve().parents[1] / "logs"
231+
save_dir.mkdir(parents=True, exist_ok=True)
232+
with open(save_dir / "log.pkl", "wb") as f:
233+
pickle.dump(log_cmd, f)
234+
finally:
235+
apply_command(cf, Command.NOTIFYSETPOINTSTOP, [])
236+
apply_command(cf, Command.LAND, [0.0, 2.0]) # Args are height and duration
238237

239238

240239
if __name__ == "__main__":
240+
logging.basicConfig(level=logging.INFO)
241241
fire.Fire(main)

0 commit comments

Comments
 (0)