
[WIP] Port HIL SERL #644

Open · wants to merge 122 commits into main
Conversation

@AdilZouitine (Member) commented Jan 17, 2025

What this does

⚠️ This PR is not ready to be merged.

We evaluate the actor-learner architecture on ManiSkill.

  • Implements the actor-learner process:

    1. An actor machine interacts with the environment and sends data to a learner machine.
    2. The learner updates its weights using this data and sends the updated weights back to the actor.
  • Increases learning speed by 50% using a shared encoder for the ensemble critics.

    • Previously, each critic made a separate forward pass through the encoder, duplicating work.
    • Now, the observation is passed through the encoder only once, and the resulting representation is sent to the critic heads (see the sketch below).
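
A minimal sketch of the shared-encoder idea, with hypothetical class and parameter names (the PR's actual SAC modules will differ):

import torch
import torch.nn as nn

class SharedEncoderEnsembleCritic(nn.Module):
    """Encode the observation once, then run only the small critic heads per member."""

    def __init__(self, encoder: nn.Module, feature_dim: int, action_dim: int, num_critics: int = 2):
        super().__init__()
        self.encoder = encoder
        self.heads = nn.ModuleList(
            nn.Sequential(
                nn.Linear(feature_dim + action_dim, 256),
                nn.ReLU(),
                nn.Linear(256, 1),
            )
            for _ in range(num_critics)
        )

    def forward(self, obs: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        features = self.encoder(obs)  # single encoder forward pass, shared by all heads
        x = torch.cat([features, actions], dim=-1)
        return torch.stack([head(x) for head in self.heads], dim=0)  # (num_critics, batch, 1)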

How it was tested

  • We trained an agent on ManiSkill using this actor-learner architecture.

How to check out & try it (for the reviewer) 😃

  • Install ManiSkill.

Examples:

python lerobot/scripts/server/actor_server.py policy=sac_maniskill env=maniskill_example device=cuda wandb.enable=True

python lerobot/scripts/server/learner_server.py policy=sac_maniskill env=maniskill_example device=cuda wandb.enable=True

@michel-aractingi michel-aractingi force-pushed the user/adil-zouitine/2025-1-7-port-hil-serl-new branch from b1be31a to 2211209 Compare February 3, 2025 15:11
cfg.env.wrapper.ee_action_space_params is not None
and cfg.env.wrapper.ee_action_space_params.use_gamepad
):
# env = ActionScaleWrapper(env=env, ee_action_space_params=cfg.env.wrapper.ee_action_space_params)


What's the reason for the ActionScaleWrapper being commented out?

@AdilZouitine AdilZouitine changed the title [WIP] Fix SAC and port HIL SERL [WIP] Port HIL SERL Mar 18, 2025
@AdilZouitine AdilZouitine force-pushed the user/adil-zouitine/2025-1-7-port-hil-serl-new branch from 9a68f20 to ae12807 Compare March 24, 2025 11:05
@AdilZouitine AdilZouitine changed the base branch from user/michel-aractingi/2024-11-27-port-hil-serl to main March 24, 2025 11:07
@AdilZouitine AdilZouitine force-pushed the user/adil-zouitine/2025-1-7-port-hil-serl-new branch from dd50635 to 313812d Compare March 24, 2025 13:16
ChorntonYoel and others added 17 commits March 28, 2025 17:18
Co-authored-by: Daniel Ritchie <[email protected]>
Co-authored-by: resolver101757 <[email protected]>
Co-authored-by: Jannik Grothusen <[email protected]>
Co-authored-by: Remi <[email protected]>
Co-authored-by: Michel Aractingi <[email protected]>
Co-authored-by: KeWang1017 <[email protected]>
…ing logic

- Added `num_subsample_critics`, `critic_target_update_weight`, and `utd_ratio` to SACConfig.
- Implemented target entropy calculation in SACPolicy if not provided.
- Introduced subsampling of critics to prevent overfitting during updates.
- Updated temperature loss calculation to use the new target entropy.
- Added comments for future UTD update implementation.

These changes improve the flexibility and performance of the SAC implementation.
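
For reference, a hedged sketch of what the critic subsampling could look like (REDQ-style; the function and argument names are assumptions, not the PR's API):

import torch

def subsample_min_q(q_values: torch.Tensor, num_subsample_critics: int) -> torch.Tensor:
    # q_values: (num_critics, batch, 1). Take the minimum over a random subset
    # of critics so each update only trusts part of the ensemble.
    idx = torch.randperm(q_values.shape[0])[:num_subsample_critics]
    return q_values[idx].min(dim=0).values  # (batch, 1)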
…n handling

- Updated action selection to use distribution sampling and log probabilities for better stochastic behavior.
- Enhanced standard deviation clamping to prevent extreme values, ensuring stability in policy outputs.
- Cleaned up code by removing unnecessary comments and improving readability.

These changes aim to refine the SAC implementation, enhancing its robustness and performance during training and inference.
- Updated standard deviation parameterization in SACConfig to 'softplus' with defined min and max values for improved stability.
- Modified action sampling in SACPolicy to use reparameterized sampling, ensuring better gradient flow and log probability calculations.
- Cleaned up log probability calculations in TanhMultivariateNormalDiag for clarity and efficiency.
- Increased evaluation frequency in YAML configuration to 50000 for more efficient training cycles.

These changes aim to enhance the robustness and performance of the SAC implementation during training and inference.
…d stability

- Updated SACConfig to replace standard deviation parameterization with log_std_min and log_std_max for better control over action distributions.
- Modified SACPolicy to streamline action selection and log probability calculations, enhancing stochastic behavior.
- Removed deprecated TanhMultivariateNormalDiag class to simplify the codebase and improve maintainability.

These changes aim to enhance the robustness and performance of the SAC implementation during training and inference.
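
As a reference for how clamped log-std bounds, reparameterized sampling, and the tanh correction usually fit together, here is a hedged sketch (the bounds and names are assumptions, not taken from the PR):

import torch

def sample_action(mean: torch.Tensor, log_std: torch.Tensor,
                  log_std_min: float = -5.0, log_std_max: float = 2.0):
    log_std = log_std.clamp(log_std_min, log_std_max)  # bounded std for numerical stability
    dist = torch.distributions.Normal(mean, log_std.exp())
    x = dist.rsample()      # reparameterized sample: gradients flow through it
    action = torch.tanh(x)  # squash into [-1, 1]
    # Change-of-variables correction for the tanh squashing.
    log_prob = dist.log_prob(x) - torch.log(1.0 - action.pow(2) + 1e-6)
    return action, log_prob.sum(-1, keepdim=True)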
michel-aractingi and others added 12 commits March 28, 2025 17:18
Added support for hil_serl classifier to be trained with train.py
Run classifier training with: python lerobot/scripts/train.py --policy.type=hilserl_classifier
Fixes in find_joint_limits, control_robot, and end_effector_control_utils
…rties

- Introduced `WrapperConfig` dataclass for environment wrapper configurations.
- Updated `ManiskillEnvConfig` to include a `wrapper` field for enhanced environment management.
- Modified `SACConfig` to return `None` for `observation_delta_indices` and `action_delta_indices` properties.
- Refactored `make_robot_env` function to improve readability and maintainability.
Moved HilSerl env config to configs/env/configs.py
Fixes in actor_server, modeling_sac, and configuration_sac
Added the option to ignore missing keys in env_cfg in the get_features_from_env_config function
- Implemented process-specific logging for actor and learner servers to improve traceability.
- Created a dedicated logs directory and ensured it exists before logging.
- Initialized logging with explicit log files for each process, including actor transitions, interactions, and policy.
- Updated the actor CLI to validate configuration and set up logging accordingly.
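
A minimal sketch of process-specific logging along these lines (the directory and file names are illustrative, not the PR's exact paths):

import logging
from pathlib import Path

def init_process_logging(process_name: str, log_dir: str = "logs") -> logging.Logger:
    Path(log_dir).mkdir(parents=True, exist_ok=True)  # ensure the logs directory exists
    logger = logging.getLogger(process_name)
    handler = logging.FileHandler(Path(log_dir) / f"{process_name}.log")
    handler.setFormatter(logging.Formatter("%(asctime)s [%(name)s] %(levelname)s %(message)s"))
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)
    return logger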
- Simplified the `image_features` property to directly iterate over `input_features`.
- Removed unused imports and unnecessary code related to main execution, enhancing clarity and maintainability.
- Rearranged import statements for better readability.
- Removed unused imports and streamlined the code structure.
- Consolidated logging initialization and enhanced logging for training processes.
- Improved handling of training state loading and resume logic.
- Refactored transition and interaction message processing for better readability and maintainability.
- Added detailed comments and documentation for clarity.
- Consolidated logging initialization and enhanced logging for actor processes.
- Streamlined the handling of gRPC connections and process management.
- Improved readability by organizing core algorithm functions and communication functions.
- Added detailed comments and documentation for clarity.
- Ensured proper queue management and shutdown handling for actor processes.
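
A hedged sketch of queue draining with shutdown handling of this kind (send_transition and the other names are hypothetical):

import queue
import threading

def stream_transitions(transitions: queue.Queue, shutdown_event: threading.Event, send_transition):
    # Poll with a timeout so the loop notices shutdown_event promptly.
    while not shutdown_event.is_set():
        try:
            transition = transitions.get(timeout=0.1)
        except queue.Empty:
            continue
        send_transition(transition)  # e.g. push the transition over gRPC to the learner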
…onality

- Updated the `forward` method in `SACPolicy` to handle loss computation for actor, critic, and temperature models.
- Replaced direct calls to `compute_loss_*` methods with a unified `forward` method in `learner_server`.
- Enhanced batch processing by consolidating input parameters into a single dictionary for better readability and maintainability.
- Removed redundant code and improved documentation for clarity.
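
Presumably the unified forward dispatches on the requested loss; a hypothetical sketch of the SACPolicy method (names taken from the compute_loss_* calls mentioned above):

def forward(self, batch: dict, model: str = "critic"):
    # Single entry point for all three losses instead of separate compute_loss_* calls.
    if model == "critic":
        return self.compute_loss_critic(**batch)
    if model == "actor":
        return self.compute_loss_actor(**batch)
    if model == "temperature":
        return self.compute_loss_temperature(**batch)
    raise ValueError(f"Unknown model: {model}")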
- Enhanced type annotations for variables in the `SACPolicy` class to improve code clarity.
- Updated method calls to use keyword arguments for better readability.
- Streamlined the extraction of batch components, ensuring consistent typing across the class methods.
…f gamepad

Minor modifications in gym_manipulator to quantize the gripper actions
Clamped the observations after F.resize in the ConvertToLeRobotObservation wrapper: due to a bug in F.resize, images were returned exceeding the maximum value of 1.0
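
A minimal reproduction of the clamping fix (the tensor shape and target size are illustrative):

import torch
from torchvision.transforms import functional as F

img = torch.rand(3, 240, 320)        # normalized camera frame in [0, 1]
resized = F.resize(img, [128, 128])  # can overshoot 1.0 with some interpolation settings
resized = resized.clamp(0.0, 1.0)    # clamp back into the valid range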
@AdilZouitine AdilZouitine force-pushed the user/adil-zouitine/2025-1-7-port-hil-serl-new branch from ad51d89 to 808cf63 Compare March 28, 2025 17:20


@dataclass
class HILSerlConfig:

Contributor: Should we drop it?

from huggingface_hub import PyTorchModelHubMixin


class HILSerlPolicy(

Contributor: It seems that it could be dropped too.

current_position = robot.follower_arms["main"].read("Present_Position")
trajectory = torch.from_numpy(
np.linspace(current_position, target_position, 50)
) # NOTE: 30 is just an aribtrary number

@helper2424 (Contributor) commented Mar 28, 2025:

Nice, GH suggests a fix, so here is a small fix:

Suggested change
) # NOTE: 30 is just an aribtrary number
) # NOTE: 30 is just an arbitrary number

@@ -246,14 +263,21 @@ def control_loop(
while timestamp < control_time_s:
start_loop_t = time.perf_counter()

current_joint_positions = robot.follower_arms["main"].read("Present_Position")

@helper2424 (Contributor) commented Mar 28, 2025:

A small fix, the linter generates an error:

Suggested change
current_joint_positions = robot.follower_arms["main"].read("Present_Position")
# current_joint_positions = robot.follower_arms["main"].read("Present_Position")

@@ -0,0 +1,594 @@
#!/usr/bin/env python

Contributor: Should we drop the file?


MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB
MAX_WORKERS = 3 # Stream parameters, send transitions and interactions
STUTDOWN_TIMEOUT = 10

Suggested change
STUTDOWN_TIMEOUT = 10
SHUTDOWN_TIMEOUT = 10

Contributor: Linter fix.


shutdown_event.wait()
logging.info("[LEARNER] Stopping gRPC server...")
server.stop(learner_service.STUTDOWN_TIMEOUT)

Suggested change
server.stop(learner_service.STUTDOWN_TIMEOUT)
server.stop(learner_service.SHUTDOWN_TIMEOUT)

Contributor: Linter fix; should be merged together with the other fix for STUTDOWN_TIMEOUT in learner_service.

@@ -107,8 +106,9 @@ def validate(self):
train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
self.output_dir = Path("outputs/train") / train_dir

if isinstance(self.dataset.repo_id, list):
raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")
if self.dataset is not None:

@@ -0,0 +1,412 @@
#!/usr/bin/env python

Contributor: It seems that this file could be dropped too.

)

# Initialize kinematics instance for the appropriate robot type
robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so100")


Suggested change
robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so100")
robot_type = getattr(env.unwrapped.robot.config, "type", "so100")

Also, I'm a bit against the getattr with a default of "so100", as it leads to silent bugs. It took me a minute to realize why my moss was constrained.

self.use_gripper = use_gripper

# Initialize kinematics instance for the appropriate robot type
robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so100")


Suggested change
robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so100")
robot_type = getattr(env.unwrapped.robot.config, "type", "so100")

if not self.robot.is_connected:
self.robot.connect()

self.initial_follower_position = robot.follower_arms["main"].read("Present_Position")


It seems this is not used.
