[Port HIL-SERL] Misc updates for RLPD on real robot #714

Closed
129 changes: 85 additions & 44 deletions lerobot/scripts/eval_on_robot.py
@@ -39,23 +39,37 @@
import argparse
import logging
import time
from torchvision import transforms
from copy import deepcopy


import cv2
import numpy as np
import torch
from tqdm import tqdm, trange
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset


from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.robot_devices.control_utils import busy_wait, is_headless, reset_follower_position
from lerobot.common.robot_devices.control_utils import busy_wait, is_headless, reset_follower_position, predict_action
from lerobot.common.robot_devices.robots.factory import Robot, make_robot
from lerobot.common.utils.utils import (
init_hydra_config,
init_logging,
log_say,
)


def get_classifier(pretrained_path, config_path):
def manual_reset_follower_position(robot, manual_reset_time_s):
timestamp = 0.0
start_reset_t = time.perf_counter()
# Let the operator teleoperate the follower back to its reset pose for a fixed duration
pbar = tqdm(total=manual_reset_time_s, desc="Manual reset", unit="s")
while timestamp < manual_reset_time_s:
new_timestamp = time.perf_counter() - start_reset_t
_ = robot.teleop_step(record_data=False)
pbar.update(new_timestamp - timestamp)
timestamp = new_timestamp
pbar.close()

def get_classifier(pretrained_path, config_path, device):
if pretrained_path is None or config_path is None:
return

@@ -69,7 +83,7 @@ def get_classifier(pretrained_path, config_path):
classifier_config.num_cameras = len(cfg.training.image_keys) # TODO automate these paths
model = Classifier(classifier_config)
model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict())
model = model.to("mps")
model = model.to(device)
return model
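
For reference, a minimal sketch of how the device-parameterised loader above might be called; the checkpoint and config paths are placeholders (not files from this PR), and the device probe is just one reasonable default.

```python
import torch

# Pick a device; fall back to Apple's MPS backend or CPU when CUDA is absent.
device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")

reward_classifier = get_classifier(
    pretrained_path="outputs/classifier/checkpoint",  # placeholder checkpoint directory
    config_path="outputs/classifier/config.yaml",     # placeholder config used to build the Classifier
    device=device,
)
```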


@@ -81,6 +95,9 @@ def rollout(
control_time_s: float = 20,
use_amp: bool = True,
display_cameras: bool = False,
manual_reset_time_s: float | None = None,
image_transforms: transforms.Compose | None = None,
online_dataset: LeRobotDataset | None = None,
) -> dict:
"""Run a batched policy rollout on the real robot.

@@ -105,79 +122,101 @@
Returns:
The dictionary described above.
"""
# TODO (michel-aractingi): Infer the device from policy parameters when policy is added
# assert isinstance(policy, nn.Module), "Policy must be a PyTorch nn module."
# device = get_device_from_parameters(policy)

# define keyboard listener
listener, events = init_keyboard_listener()

# Reset the policy. TODO (michel-aractingi) add real policy evaluation once the code is ready.
# policy.reset()

# NOTE: sorting to make sure the key sequence is the same during training and testing.

# Initialize
observation = robot.capture_observation()
image_keys = [key for key in observation if "image" in key]
image_keys.sort()

image_keys = sorted([key for key in observation if "image" in key])
init_pos = robot.follower_arms["main"].read("Present_Position")
timestamp = 0.0

all_observations = []
all_actions = []
all_rewards = []
all_successes = []
online_dataset_frames = []

start_episode_t = time.perf_counter()
init_pos = robot.follower_arms["main"].read("Present_Position")
timestamp = 0.0
while timestamp < control_time_s:
start_loop_t = time.perf_counter()

# Apply the next action.
while events["pause_policy"] and not events["human_intervention_step"]:
busy_wait(0.5)

if events["human_intervention_step"]:
# take over the robot's actions
observation, action = robot.teleop_step(record_data=True)
action = action["action"] # teleop step returns torch tensors but in a dict
action = action["action"]
else:
# explore with policy
with torch.inference_mode():
# TODO (michel-aractingi) replace this part with policy (predict_action)
action = robot.follower_arms["main"].read("Present_Position")
action = torch.from_numpy(action)
observation = robot.capture_observation()
action = predict_action(observation, policy, policy.device, use_amp=False)
robot.send_action(action)
# action = predict_action(observation, policy, device, use_amp)

observation = robot.capture_observation()
images = []
all_observations.append(observation)
if online_dataset is not None:
intervention = 1 if events["human_intervention_step"] else 0
online_dataset_frames.append({**observation, **{"intervention": intervention}, **{"action": action}})

all_actions.append(action)
all_successes.append(torch.tensor([False]))

for key in image_keys:
if display_cameras:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
cv2.waitKey(1)
images.append(observation[key].to("mps"))

reward = reward_classifier.predict_reward(images) if reward_classifier is not None else 0.0
all_rewards.append(reward)

# print("REWARD : ", reward)

all_actions.append(action)
all_successes.append(torch.tensor([False]))

dt_s = time.perf_counter() - start_loop_t
busy_wait(1 / fps - dt_s)
# Maintain fps
busy_wait(1 / fps - (time.perf_counter() - start_loop_t))

timestamp = time.perf_counter() - start_episode_t
if events["exit_early"]:
events["exit_early"] = False
events["human_intervention_step"] = False
events["pause_policy"] = False
break
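
The policy branch above delegates observation preprocessing to `predict_action` from `control_utils`. As a rough sketch only (not the lerobot source), a helper with that signature is expected to do roughly the following before querying the policy:

```python
# Rough sketch of the expected contract of predict_action(observation, policy, device, use_amp).
import torch

def predict_action_sketch(observation: dict, policy, device, use_amp: bool = False):
    batch = {}
    for name, value in observation.items():
        if "image" in name:
            value = value.type(torch.float32) / 255.0  # uint8 HWC image -> float in [0, 1]
            value = value.permute(2, 0, 1)             # HWC -> CHW, as policies expect
        batch[name] = value.unsqueeze(0).to(device)    # add a batch dimension and move to device
    with torch.inference_mode():
        # AMP/autocast is omitted in this sketch; the real helper takes a use_amp flag.
        action = policy.select_action(batch)
    return action.squeeze(0).to("cpu")                 # single action the robot can consume
```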

# Reset robot position
if manual_reset_time_s is not None:
log_say("Manual environment reset.", play_sounds=True)
manual_reset_follower_position(robot, manual_reset_time_s)
else:
log_say("Automatic environment reset.", play_sounds=True)
reset_follower_position(robot, target_position=init_pos)

# Compute rewards
log_say("Comuputing rewards")
reward_progbar = trange(len(all_observations), desc="Comuputing rewards")
for i in reward_progbar:
# Preprocess images
images = []
for img_key in image_keys:
img = deepcopy(all_observations[i][img_key])
img = img.permute(2, 0, 1) if img.dim() == 3 and img.shape[-1] == 3 else img
img = img.unsqueeze(0)
if image_transforms:
img = image_transforms(img)
images.append(img.to(reward_classifier.device))

with torch.inference_mode():
reward = reward_classifier.predict_reward(images) if reward_classifier is not None else 0.0
all_rewards.append(reward)
if online_dataset is not None:
online_dataset_frames[i]["next.reward"] = reward
online_dataset.add_frame(online_dataset_frames[i])
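
To make the deferred-reward bookkeeping concrete, a sketch of one buffered frame before and after this pass; the observation keys, shapes, and values are illustrative assumptions, not the repo's actual feature names.

```python
import torch

# Illustrative frame layout (keys and shapes are placeholders):
frame = {
    "observation.images.cam_top": torch.zeros(480, 640, 3, dtype=torch.uint8),  # raw camera frame
    "observation.state": torch.zeros(6),  # follower joint positions
    "action": torch.zeros(6),             # command sent (or teleoperated) this step
    "intervention": 1,                    # 1 if a human drove this step, else 0
}
# After the reward pass above, the classifier's output is attached and the frame is
# handed to the online dataset:
frame["next.reward"] = 1.0                # e.g. the reward classifier's prediction for this step
online_dataset.add_frame(frame)
```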

reset_follower_position(robot, target_position=init_pos)

# Prepare return data
dones = torch.tensor([False] * len(all_actions))
dones[-1] = True
# Stack the sequence along the first dimension so that we have (batch, sequence, *) tensors.
ret = {
rollout_data = {
"action": torch.stack(all_actions, dim=1),
"next.reward": torch.stack(all_rewards, dim=1),
"next.success": torch.stack(all_successes, dim=1),
@@ -186,7 +225,7 @@ def rollout(

listener.stop()

return ret
return rollout_data, online_dataset


def eval_policy(
@@ -199,6 +238,10 @@ def eval_policy(
display_cameras: bool = False,
reward_classifier_pretrained_path: str | None = None,
reward_classifier_config_file: str | None = None,
device: str = "mps",
manual_reset_time_s: float | None = None,
image_transforms: transforms.Compose | None = None,
online_dataset: LeRobotDataset | None = None,
) -> dict:
"""
Args:
@@ -209,24 +252,22 @@ def eval_policy(
Dictionary with metrics and data regarding the rollouts.
"""
# TODO (michel-aractingi) comment this out for testing with a fixed policy
# assert isinstance(policy, Policy)
# policy.eval()
assert isinstance(policy, Policy)
policy.eval()

sum_rewards = []
max_rewards = []
successes = []
rollouts = []

start_eval = time.perf_counter()
progbar = trange(n_episodes, desc="Evaluating policy on real robot")
reward_classifier = get_classifier(reward_classifier_pretrained_path, reward_classifier_config_file)
reward_classifier = get_classifier(reward_classifier_pretrained_path, reward_classifier_config_file, device)

for _ in progbar:
rollout_data = rollout(
robot, policy, reward_classifier, fps, control_time_s, use_amp, display_cameras
rollout_data, online_dataset = rollout(
robot, policy, reward_classifier, fps, control_time_s, use_amp, display_cameras, manual_reset_time_s=manual_reset_time_s, image_transforms=image_transforms, online_dataset=online_dataset
)

rollouts.append(rollout_data)
sum_rewards.append(sum(rollout_data["next.reward"]))
max_rewards.append(max(rollout_data["next.reward"]))
successes.append(rollout_data["next.success"][-1])
@@ -260,7 +301,7 @@ def eval_policy(
if robot.is_connected:
robot.disconnect()

return info
return info, online_dataset
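
A sketch of how a caller might wire up the new arguments. `robot`, `policy`, and `online_dataset` are assumed to come from the surrounding script; the parameter names before `display_cameras` are not visible in this hunk and are assumed, and all paths and values are placeholders.

```python
import torch
from torchvision import transforms

# Placeholder preprocessing applied to each (1, C, H, W) image tensor in the reward pass.
image_transforms = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ConvertImageDtype(torch.float32),
])

info, online_dataset = eval_policy(
    robot,
    policy,
    fps=30,                  # assumed earlier parameters, not shown in this hunk
    n_episodes=5,
    control_time_s=20,
    display_cameras=False,
    reward_classifier_pretrained_path="outputs/classifier/checkpoint",  # placeholder
    reward_classifier_config_file="outputs/classifier/config.yaml",     # placeholder
    device="cuda",
    manual_reset_time_s=10.0,             # None -> automatic reset to the initial pose
    image_transforms=image_transforms,
    online_dataset=online_dataset,        # LeRobotDataset buffer for RLPD, or None to disable
)
```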


def init_keyboard_listener():