[WIP] Port HIL SERL #644
base: main
Conversation
force-pushed the branch from b1be31a to 2211209
Review thread on lerobot/common/policies/hilserl/classifier/modeling_classifier.py (outdated, resolved)
    cfg.env.wrapper.ee_action_space_params is not None
    and cfg.env.wrapper.ee_action_space_params.use_gamepad
):
    # env = ActionScaleWrapper(env=env, ee_action_space_params=cfg.env.wrapper.ee_action_space_params)
What's the reason for the action scale wrapper being commented out?
force-pushed the branch from 9a68f20 to ae12807
force-pushed the branch from dd50635 to 313812d
Co-authored-by: Daniel Ritchie <[email protected]>
Co-authored-by: resolver101757 <[email protected]>
Co-authored-by: Jannik Grothusen <[email protected]>
Co-authored-by: Remi <[email protected]>
Co-authored-by: Michel Aractingi <[email protected]>
…licy on the robot (#541)
Co-authored-by: Yoel <[email protected]>
Co-authored-by: Yoel <[email protected]>
Co-authored-by: KeWang1017 <[email protected]>
…ing logic
- Added `num_subsample_critics`, `critic_target_update_weight`, and `utd_ratio` to SACConfig.
- Implemented target entropy calculation in SACPolicy if not provided.
- Introduced subsampling of critics to prevent overfitting during updates.
- Updated temperature loss calculation to use the new target entropy.
- Added comments for future UTD update implementation.
These changes improve the flexibility and performance of the SAC implementation.
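To make the subsampling concrete, here is a minimal sketch — not the repository code: only `num_subsample_critics` and the target-entropy fallback come from the commit message, everything else is assumed:

import torch

# Sketch: take the min over a random subset of the critic ensemble,
# as in clipped double-Q learning, to reduce overfitting during updates.
def min_q_over_subsampled_critics(q_targets: torch.Tensor, num_subsample_critics: int | None) -> torch.Tensor:
    # q_targets: (num_critics, batch_size) stacked target-critic outputs
    if num_subsample_critics is not None:
        idx = torch.randperm(q_targets.shape[0])[:num_subsample_critics]
        q_targets = q_targets[idx]
    return q_targets.min(dim=0).values

# A common SAC default when no target entropy is provided is -|A|,
# the negative of the action dimensionality.
def default_target_entropy(action_dim: int) -> float:
    return -float(action_dim)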
…s & check script (#578)
…n handling
- Updated action selection to use distribution sampling and log probabilities for better stochastic behavior.
- Enhanced standard deviation clamping to prevent extreme values, ensuring stability in policy outputs.
- Cleaned up code by removing unnecessary comments and improving readability.
These changes aim to refine the SAC implementation, enhancing its robustness and performance during training and inference.
- Updated standard deviation parameterization in SACConfig to 'softplus' with defined min and max values for improved stability.
- Modified action sampling in SACPolicy to use reparameterized sampling, ensuring better gradient flow and log probability calculations.
- Cleaned up log probability calculations in TanhMultivariateNormalDiag for clarity and efficiency.
- Increased evaluation frequency in YAML configuration to 50000 for more efficient training cycles.
These changes aim to enhance the robustness and performance of the SAC implementation during training and inference.
…d stability
- Updated SACConfig to replace standard deviation parameterization with log_std_min and log_std_max for better control over action distributions.
- Modified SACPolicy to streamline action selection and log probability calculations, enhancing stochastic behavior.
- Removed deprecated TanhMultivariateNormalDiag class to simplify the codebase and improve maintainability.
These changes aim to enhance the robustness and performance of the SAC implementation during training and inference.
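Taken together, these three commits describe a fairly standard reparameterized tanh-Gaussian policy; a self-contained sketch under those assumptions (none of these names are quoted from the PR):

import torch
from torch.distributions import Normal

LOG_STD_MIN, LOG_STD_MAX = -5.0, 2.0  # assumed clamp range for stability

def sample_action(mean: torch.Tensor, log_std: torch.Tensor):
    log_std = log_std.clamp(LOG_STD_MIN, LOG_STD_MAX)  # prevent extreme std values
    dist = Normal(mean, log_std.exp())
    x = dist.rsample()      # reparameterized sample, so gradients flow to the policy
    action = torch.tanh(x)  # squash into [-1, 1]
    # Change-of-variables correction for the tanh squashing
    log_prob = (dist.log_prob(x) - torch.log(1 - action.pow(2) + 1e-6)).sum(-1)
    return action, log_prob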
Added support for the hil_serl classifier to be trained with train.py. Run classifier training with:
python lerobot/scripts/train.py --policy.type=hilserl_classifier
Fixes in find_joint_limits, control_robot, end_effector_control_utils.
…rties
- Introduced `WrapperConfig` dataclass for environment wrapper configurations.
- Updated `ManiskillEnvConfig` to include a `wrapper` field for enhanced environment management.
- Modified `SACConfig` to return `None` for `observation_delta_indices` and `action_delta_indices` properties.
- Refactored `make_robot_env` function to improve readability and maintainability.
Moved the HilSerl env config to configs/env/configs.py. Fixes in actor_server, modeling_sac, and configuration_sac. Added the possibility of ignoring missing keys in env_cfg in the get_features_from_env_config function.
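A hypothetical sketch of what "ignoring missing keys" could look like — the function name comes from the commit, while the signature and keys are assumptions:

def get_features_from_env_config(env_cfg: dict, ignore_missing_keys: bool = False) -> dict:
    features = {}
    for key in ("observation_space", "action_space"):  # assumed keys
        if key not in env_cfg:
            if ignore_missing_keys:
                continue  # tolerate the gap instead of raising
            raise KeyError(f"env_cfg is missing required key: {key}")
        features[key] = env_cfg[key]
    return features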
- Implemented process-specific logging for actor and learner servers to improve traceability.
- Created a dedicated logs directory and ensured it exists before logging.
- Initialized logging with explicit log files for each process, including actor transitions, interactions, and policy.
- Updated the actor CLI to validate configuration and set up logging accordingly.
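A minimal sketch of per-process log files as described above (paths and names are illustrative, not the PR's actual helper):

import logging
import os
from pathlib import Path

def init_process_logging(process_name: str, log_dir: str = "logs") -> logging.Logger:
    Path(log_dir).mkdir(parents=True, exist_ok=True)  # ensure the logs directory exists
    logger = logging.getLogger(process_name)
    handler = logging.FileHandler(Path(log_dir) / f"{process_name}_{os.getpid()}.log")
    handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)
    return logger

# e.g. one logger per process: "actor_transitions", "actor_interactions", "learner_policy"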
- Simplified the `image_features` property to directly iterate over `input_features`.
- Removed unused imports and unnecessary code related to main execution, enhancing clarity and maintainability.
- Rearranged import statements for better readability.
- Removed unused imports and streamlined the code structure.
- Removed unused imports and streamlined the code structure.
- Consolidated logging initialization and enhanced logging for training processes.
- Improved handling of training state loading and resume logic.
- Refactored transition and interaction message processing for better readability and maintainability.
- Added detailed comments and documentation for clarity.
- Consolidated logging initialization and enhanced logging for actor processes.
- Streamlined the handling of gRPC connections and process management.
- Improved readability by organizing core algorithm functions and communication functions.
- Added detailed comments and documentation for clarity.
- Ensured proper queue management and shutdown handling for actor processes.
…onality
- Updated the `forward` method in `SACPolicy` to handle loss computation for actor, critic, and temperature models.
- Replaced direct calls to `compute_loss_*` methods with a unified `forward` method in `learner_server`.
- Enhanced batch processing by consolidating input parameters into a single dictionary for better readability and maintainability.
- Removed redundant code and improved documentation for clarity.
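The unified entry point could look roughly like this — the `model` switch and the `compute_loss_*` names are inferred from the commit message, and the exact signature is an assumption:

def forward(self, batch: dict, model: str = "critic"):
    # Dispatch to the relevant loss; the learner calls forward(...) instead of
    # reaching into the compute_loss_* methods directly.
    if model == "critic":
        return self.compute_loss_critic(**batch)
    if model == "actor":
        return self.compute_loss_actor(**batch)
    if model == "temperature":
        return self.compute_loss_temperature(**batch)
    raise ValueError(f"Unknown model: {model}")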
- Enhanced type annotations for variables in the `SACPolicy` class to improve code clarity.
- Updated method calls to use keyword arguments for better readability.
- Streamlined the extraction of batch components, ensuring consistent typing across the class methods.
…f gamepad. Minor modifications in gym_manipulator to quantize the gripper actions. Clamped the observations after F.resize in the ConvertToLeRobotObservation wrapper: due to a bug in F.resize, images were returned exceeding the maximum value of 1.0.
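The workaround presumably amounts to clamping right after the resize; a sketch of the pattern (tensor shapes and sizes are illustrative):

import torch
import torchvision.transforms.functional as F

img = torch.rand(3, 480, 640)  # float image already scaled to [0, 1]
resized = F.resize(img, [128, 128], antialias=True)
# Interpolation can overshoot slightly, so values may exceed 1.0; clamp to be safe.
resized = resized.clamp(0.0, 1.0)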
force-pushed the branch from ad51d89 to 808cf63
for more information, see https://pre-commit.ci
@dataclass
class HILSerlConfig:
Should we drop it?
from huggingface_hub import PyTorchModelHubMixin

class HILSerlPolicy(
It seems that it could be dropped too.
current_position = robot.follower_arms["main"].read("Present_Position")
trajectory = torch.from_numpy(
    np.linspace(current_position, target_position, 50)
)  # NOTE: 50 is just an arbitrary number
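For context, a hypothetical sketch of how such a linspace trajectory would typically be consumed — the loop, the `send_action` call, and the pacing are assumptions, not quoted code:

import time

fps = 30  # assumed control frequency
for joint_positions in trajectory:
    robot.send_action(joint_positions)  # stream each interpolated waypoint
    time.sleep(1 / fps)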
@@ -246,14 +263,21 @@ def control_loop(
    while timestamp < control_time_s:
        start_loop_t = time.perf_counter()

        current_joint_positions = robot.follower_arms["main"].read("Present_Position")
for more information, see https://pre-commit.ci
@@ -0,0 +1,594 @@
#!/usr/bin/env python
Should we drop the file?
MAX_MESSAGE_SIZE = 4 * 1024 * 1024  # 4 MB
MAX_WORKERS = 3  # Stream parameters, send transitions and interactions
STUTDOWN_TIMEOUT = 10
Suggested change:
- STUTDOWN_TIMEOUT = 10
+ SHUTDOWN_TIMEOUT = 10
Linter fix
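For context, these constants would typically be wired into the gRPC server like this — an assumed usage sketch, with service registration omitted:

from concurrent import futures

import grpc

server = grpc.server(
    futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
    options=[
        ("grpc.max_send_message_length", MAX_MESSAGE_SIZE),
        ("grpc.max_receive_message_length", MAX_MESSAGE_SIZE),
    ],
)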
shutdown_event.wait()
logging.info("[LEARNER] Stopping gRPC server...")
server.stop(learner_service.STUTDOWN_TIMEOUT)
Suggested change:
- server.stop(learner_service.STUTDOWN_TIMEOUT)
+ server.stop(learner_service.SHUTDOWN_TIMEOUT)
Linter fix, should be merged together with the other fix for STUTDOWN_TIMEOUT in learner_service.
@@ -107,8 +106,9 @@ def validate(self):
    train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
    self.output_dir = Path("outputs/train") / train_dir

    if isinstance(self.dataset.repo_id, list):
        raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")
    if self.dataset is not None:
@@ -0,0 +1,412 @@
#!/usr/bin/env python
It seems that this file could be dropped too
)

# Initialize kinematics instance for the appropriate robot type
robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so100")
Suggested change:
- robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so100")
+ robot_type = getattr(env.unwrapped.robot.config, "type", "so100")
Also, I'm a bit against the getattr with the "so100" default, as it leads to silent bugs. Took me a minute to realize why my Moss was constrained.
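The reviewer's point in miniature — an explicit attribute access fails loudly, while the quoted pattern silently falls back to so100 kinematics:

# Fails loudly with AttributeError if the config has no `type` field:
robot_type = env.unwrapped.robot.config.type
# Quoted pattern: silently assumes "so100" kinematics for any robot
# whose config lacks the attribute (e.g. a Moss arm):
robot_type = getattr(env.unwrapped.robot.config, "type", "so100")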
self.use_gripper = use_gripper

# Initialize kinematics instance for the appropriate robot type
robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so100")
Suggested change:
- robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so100")
+ robot_type = getattr(env.unwrapped.robot.config, "type", "so100")
if not self.robot.is_connected:
    self.robot.connect()

self.initial_follower_position = robot.follower_arms["main"].read("Present_Position")
It seems this is not used
What this does
We evaluate the actor-learner architecture on ManiSkill.
Implements the actor-learner process:
Increases learning speed by 50% using a shared encoder for the ensemble critics.
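A minimal sketch of what a shared encoder feeding an ensemble of critic heads could look like (all names and sizes are illustrative, not the PR's code):

import torch
import torch.nn as nn

class CriticEnsemble(nn.Module):
    def __init__(self, obs_dim: int, action_dim: int, num_critics: int = 2, hidden: int = 256):
        super().__init__()
        # The encoder is shared: the observation is encoded once per forward pass
        self.encoder = nn.Sequential(nn.Linear(obs_dim, hidden), nn.ReLU())
        self.heads = nn.ModuleList(
            nn.Sequential(nn.Linear(hidden + action_dim, hidden), nn.ReLU(), nn.Linear(hidden, 1))
            for _ in range(num_critics)
        )

    def forward(self, obs: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        z = self.encoder(obs)  # computed once, reused by every critic head
        return torch.stack([head(torch.cat([z, action], dim=-1)) for head in self.heads])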
How it was tested
How to check out & try it (for the reviewer) 😃
Examples: