-
Notifications
You must be signed in to change notification settings - Fork 953
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: Simon Alibert <[email protected]> Co-authored-by: Simon Alibert <[email protected]> Co-authored-by: Pablo <[email protected]>
- Loading branch information
1 parent
dd97452
commit 638d411
Showing
26 changed files
with
2,365 additions
and
92 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
from .act.configuration_act import ACTConfig as ACTConfig | ||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig | ||
from .pi0.configuration_pi0 import PI0Config as PI0Config | ||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig | ||
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
from dataclasses import dataclass, field | ||
|
||
from lerobot.common.optim.optimizers import AdamWConfig | ||
from lerobot.common.optim.schedulers import ( | ||
CosineDecayWithWarmupSchedulerConfig, | ||
) | ||
from lerobot.configs.policies import PreTrainedConfig | ||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature | ||
|
||
|
||
@PreTrainedConfig.register_subclass("pi0") | ||
@dataclass | ||
class PI0Config(PreTrainedConfig): | ||
# Input / output structure. | ||
n_obs_steps: int = 1 | ||
chunk_size: int = 50 | ||
n_action_steps: int = 50 | ||
|
||
normalization_mapping: dict[str, NormalizationMode] = field( | ||
default_factory=lambda: { | ||
"VISUAL": NormalizationMode.IDENTITY, | ||
"STATE": NormalizationMode.MEAN_STD, | ||
"ACTION": NormalizationMode.MEAN_STD, | ||
} | ||
) | ||
|
||
# Shorter state and action vectors will be padded | ||
max_state_dim: int = 32 | ||
max_action_dim: int = 32 | ||
|
||
# Image preprocessing | ||
resize_imgs_with_padding: tuple[int, int] = (224, 224) | ||
|
||
# Add empty images. Used by pi0_aloha_sim which adds the empty | ||
# left and right wrist cameras in addition to the top camera. | ||
empty_cameras: int = 0 | ||
|
||
# Converts the joint and gripper values from the standard Aloha space to | ||
# the space used by the pi internal runtime which was used to train the base model. | ||
adapt_to_pi_aloha: bool = False | ||
|
||
# Converts joint dimensions to deltas with respect to the current state before passing to the model. | ||
# Gripper dimensions will remain in absolute values. | ||
use_delta_joint_actions_aloha: bool = False | ||
|
||
# Tokenizer | ||
tokenizer_max_length: int = 48 | ||
|
||
# Projector | ||
proj_width: int = 1024 | ||
|
||
# Decoding | ||
num_steps: int = 10 | ||
|
||
# Attention utils | ||
use_cache: bool = True | ||
attention_implementation: str = "eager" # or fa2, flex | ||
|
||
# Finetuning settings | ||
freeze_vision_encoder: bool = True | ||
train_expert_only: bool = False | ||
train_state_proj: bool = True | ||
|
||
# Training presets | ||
optimizer_lr: float = 2.5e-5 | ||
optimizer_betas: tuple[float, float] = (0.9, 0.95) | ||
optimizer_eps: float = 1e-8 | ||
optimizer_weight_decay: float = 1e-10 | ||
|
||
scheduler_warmup_steps: int = 1_000 | ||
scheduler_decay_steps: int = 30_000 | ||
scheduler_decay_lr: float = 2.5e-6 | ||
|
||
# TODO: Add EMA | ||
|
||
def __post_init__(self): | ||
super().__post_init__() | ||
|
||
"""Input validation (not exhaustive).""" | ||
if self.n_action_steps > self.chunk_size: | ||
raise ValueError( | ||
f"The chunk size is the upper bound for the number of action steps per model invocation. Got " | ||
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`." | ||
) | ||
if self.n_obs_steps != 1: | ||
raise ValueError( | ||
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`" | ||
) | ||
|
||
if self.use_delta_joint_actions_aloha: | ||
raise NotImplementedError( | ||
"`use_delta_joint_actions_aloha` is used by pi0 for aloha real models. It is not ported yet in LeRobot." | ||
) | ||
|
||
def validate_features(self) -> None: | ||
# TODO: implement value error | ||
# if not self.image_features and not self.env_state_feature: | ||
# raise ValueError("You must provide at least one image or the environment state among the inputs.") | ||
|
||
for i in range(self.empty_cameras): | ||
key = f"observation.images.empty_camera_{i}" | ||
empty_camera = PolicyFeature( | ||
type=FeatureType.VISUAL, | ||
shape=(3, 480, 640), | ||
) | ||
self.input_features[key] = empty_camera | ||
|
||
def get_optimizer_preset(self) -> AdamWConfig: | ||
return AdamWConfig( | ||
lr=self.optimizer_lr, | ||
betas=self.optimizer_betas, | ||
eps=self.optimizer_eps, | ||
weight_decay=self.optimizer_weight_decay, | ||
) | ||
|
||
def get_scheduler_preset(self): | ||
return CosineDecayWithWarmupSchedulerConfig( | ||
peak_lr=self.optimizer_lr, | ||
decay_lr=self.scheduler_decay_lr, | ||
num_warmup_steps=self.scheduler_warmup_steps, | ||
num_decay_steps=self.scheduler_decay_steps, | ||
) | ||
|
||
@property | ||
def observation_delta_indices(self) -> None: | ||
return None | ||
|
||
@property | ||
def action_delta_indices(self) -> list: | ||
return list(range(self.chunk_size)) | ||
|
||
@property | ||
def reward_delta_indices(self) -> None: | ||
return None |
Oops, something went wrong.