Commit

Add Pi0 (#681)
Co-authored-by: Simon Alibert <[email protected]>
Co-authored-by: Pablo <[email protected]>
4 people authored Feb 4, 2025
1 parent dd97452 commit 638d411
Showing 26 changed files with 2,365 additions and 92 deletions.
52 changes: 26 additions & 26 deletions Makefile
@@ -26,7 +26,6 @@ test-end-to-end:
${MAKE} DEVICE=$(DEVICE) test-diffusion-ete-eval
${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-train
${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-eval
${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-train-with-online

test-act-ete-train:
python lerobot/scripts/train.py \
@@ -128,28 +127,29 @@ test-tdmpc-ete-eval:
--eval.batch_size=1 \
--device=$(DEVICE)

test-tdmpc-ete-train-with-online:
python lerobot/scripts/train.py \
--policy.type=tdmpc \
--env.type=pusht \
--env.obs_type=environment_state_agent_pos \
--env.episode_length=5 \
--dataset.repo_id=lerobot/pusht_keypoints \
--dataset.image_transforms.enable=true \
--dataset.episodes="[0]" \
--batch_size=2 \
--offline.steps=2 \
--online.steps=20 \
--online.rollout_n_episodes=2 \
--online.rollout_batch_size=2 \
--online.steps_between_rollouts=10 \
--online.buffer_capacity=1000 \
--online.env_seed=10000 \
--save_checkpoint=false \
--save_freq=10 \
--log_freq=1 \
--eval.use_async_envs=true \
--eval.n_episodes=1 \
--eval.batch_size=1 \
--device=$(DEVICE) \
--output_dir=tests/outputs/tdmpc_online/
# TODO(rcadene): fix online buffer to store "task"
# test-tdmpc-ete-train-with-online:
# python lerobot/scripts/train.py \
# --policy.type=tdmpc \
# --env.type=pusht \
# --env.obs_type=environment_state_agent_pos \
# --env.episode_length=5 \
# --dataset.repo_id=lerobot/pusht_keypoints \
# --dataset.image_transforms.enable=true \
# --dataset.episodes="[0]" \
# --batch_size=2 \
# --offline.steps=2 \
# --online.steps=20 \
# --online.rollout_n_episodes=2 \
# --online.rollout_batch_size=2 \
# --online.steps_between_rollouts=10 \
# --online.buffer_capacity=1000 \
# --online.env_seed=10000 \
# --save_checkpoint=false \
# --save_freq=10 \
# --log_freq=1 \
# --eval.use_async_envs=true \
# --eval.n_episodes=1 \
# --eval.batch_size=1 \
# --device=$(DEVICE) \
# --output_dir=tests/outputs/tdmpc_online/
4 changes: 4 additions & 0 deletions lerobot/common/datasets/lerobot_dataset.py
@@ -672,6 +672,10 @@ def __getitem__(self, idx) -> dict:
for cam in image_keys:
item[cam] = self.image_transforms(item[cam])

# Add task as a string
task_idx = item["task_index"].item()
item["task"] = self.meta.tasks[task_idx]

return item

def __repr__(self):
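Each `LeRobotDataset` item now carries the natural-language task string alongside its integer `task_index`; Pi0 consumes this string as its language instruction. A minimal sketch of the new behavior (repo id taken from the Makefile tests above; the actual task text depends on the dataset):

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset("lerobot/pusht_keypoints")
item = dataset[0]

# "task_index" is the integer id stored with each frame; "task" is the
# natural-language instruction resolved through the dataset metadata.
print(item["task_index"], item["task"])
```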
18 changes: 16 additions & 2 deletions lerobot/common/optim/optimizers.py
@@ -8,8 +8,6 @@
@dataclass
class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
lr: float
betas: tuple[float, float]
eps: float
weight_decay: float
grad_clip_norm: float

@@ -54,3 +52,19 @@ def build(self, params: dict) -> torch.optim.Optimizer:
kwargs = asdict(self)
kwargs.pop("grad_clip_norm")
return torch.optim.AdamW(params, **kwargs)


@OptimizerConfig.register_subclass("sgd")
@dataclass
class SGDConfig(OptimizerConfig):
lr: float = 1e-3
momentum: float = 0.0
dampening: float = 0.0
nesterov: bool = False
weight_decay: float = 0.0
grad_clip_norm: float = 10.0

def build(self, params: dict) -> torch.optim.Optimizer:
kwargs = asdict(self)
kwargs.pop("grad_clip_norm")
return torch.optim.SGD(params, **kwargs)
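With `SGDConfig` joining the registry, `betas` and `eps` can no longer live on the base class, since `torch.optim.SGD` does not accept them; each subclass now declares exactly the fields its optimizer takes. A minimal usage sketch (the linear module is a stand-in):

```python
import torch
from lerobot.common.optim.optimizers import SGDConfig

model = torch.nn.Linear(4, 2)  # stand-in module
cfg = SGDConfig(lr=1e-2, momentum=0.9)

# build() pops grad_clip_norm before forwarding the kwargs to torch.optim.SGD,
# so gradient clipping stays a config-level knob for the training loop.
optimizer = cfg.build(model.parameters())
```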
35 changes: 35 additions & 0 deletions lerobot/common/optim/schedulers.py
@@ -54,3 +54,38 @@ def lr_lambda(current_step):
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)))

return LambdaLR(optimizer, lr_lambda, -1)


@LRSchedulerConfig.register_subclass("cosine_decay_with_warmup")
@dataclass
class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
"""Used by Physical Intelligence to train Pi0"""

num_warmup_steps: int
num_decay_steps: int
peak_lr: float
decay_lr: float

def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
del num_training_steps

def lr_lambda(current_step):
def linear_warmup_schedule(current_step):
if current_step <= 0:
return 1 / (self.num_warmup_steps + 1)
frac = 1 - current_step / self.num_warmup_steps
return (1 / (self.num_warmup_steps + 1) - 1) * frac + 1

def cosine_decay_schedule(current_step):
step = min(current_step, self.num_decay_steps)
cosine_decay = 0.5 * (1 + math.cos(math.pi * step / self.num_decay_steps))
alpha = self.decay_lr / self.peak_lr
decayed = (1 - alpha) * cosine_decay + alpha
return decayed

if current_step < self.num_warmup_steps:
return linear_warmup_schedule(current_step)

return cosine_decay_schedule(current_step)

return LambdaLR(optimizer, lr_lambda, -1)
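The lambda returns a multiplier on the optimizer's base learning rate, so the optimizer should be built with `lr=peak_lr`: the multiplier ramps linearly from about `1 / (num_warmup_steps + 1)` up to 1, then follows a cosine from 1 down to `decay_lr / peak_lr`, where it stays after `num_decay_steps`. A minimal sketch using the Pi0 preset values:

```python
import torch
from lerobot.common.optim.schedulers import CosineDecayWithWarmupSchedulerConfig

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=2.5e-5)  # base lr == peak_lr

cfg = CosineDecayWithWarmupSchedulerConfig(
    num_warmup_steps=1_000,
    num_decay_steps=30_000,
    peak_lr=2.5e-5,
    decay_lr=2.5e-6,
)
scheduler = cfg.build(optimizer, num_training_steps=30_000)  # arg unused here

for _ in range(10):
    optimizer.step()
    scheduler.step()
print(scheduler.get_last_lr())  # still deep in warmup: a tiny fraction of peak_lr
```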
1 change: 1 addition & 0 deletions lerobot/common/policies/__init__.py
@@ -1,4 +1,5 @@
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
9 changes: 9 additions & 0 deletions lerobot/common/policies/factory.py
@@ -25,6 +25,7 @@
from lerobot.common.envs.utils import env_to_policy_features
from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
@@ -50,6 +51,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy

return VQBeTPolicy
elif name == "pi0":
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy

return PI0Policy
else:
raise NotImplementedError(f"Policy with name {name} is not implemented.")

@@ -63,6 +68,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return ACTConfig(**kwargs)
elif policy_type == "vqbet":
return VQBeTConfig(**kwargs)
elif policy_type == "pi0":
return PI0Config(**kwargs)
else:
raise ValueError(f"Policy type '{policy_type}' is not available.")

@@ -141,4 +148,6 @@ def make_policy(
policy.to(device)
assert isinstance(policy, nn.Module)

# policy = torch.compile(policy, mode="reduce-overhead")

return policy
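With these two branches, `pi0` resolves through the same string-keyed factory as the other policies, and the heavy `modeling_pi0` import stays deferred until the class is actually requested. A quick sketch:

```python
from lerobot.common.policies.factory import get_policy_class, make_policy_config

cfg = make_policy_config("pi0", n_action_steps=50)
policy_cls = get_policy_class("pi0")  # imports modeling_pi0 only at this point
print(type(cfg).__name__, policy_cls.__name__)  # -> PI0Config PI0Policy
```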
6 changes: 6 additions & 0 deletions lerobot/common/policies/normalize.py
@@ -140,6 +140,9 @@ def __init__(
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
batch = dict(batch) # shallow copy avoids mutating the input batch
for key, ft in self.features.items():
if key not in batch:
continue

norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
if norm_mode is NormalizationMode.IDENTITY:
continue
@@ -210,6 +213,9 @@ def __init__(
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
batch = dict(batch) # shallow copy avoids mutating the input batch
for key, ft in self.features.items():
if key not in batch:
continue

norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
if norm_mode is NormalizationMode.IDENTITY:
continue
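The `key not in batch` guard matters for Pi0: a feature can be declared in the config (for instance the padded `empty_cameras` below) yet be absent from a given batch, and normalization now skips it instead of raising a `KeyError`. A sketch of the behavior, assuming `Normalize` is constructed with `(features, norm_map, stats)` as elsewhere in the codebase:

```python
import torch
from lerobot.common.policies.normalize import Normalize
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature

features = {
    "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(6,)),
    "observation.images.top": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
norm_map = {
    FeatureType.STATE: NormalizationMode.MEAN_STD,
    FeatureType.VISUAL: NormalizationMode.MEAN_STD,
}
stats = {
    "observation.state": {"mean": torch.zeros(6), "std": torch.ones(6)},
    "observation.images.top": {"mean": torch.zeros(3, 1, 1), "std": torch.ones(3, 1, 1)},
}
normalize = Normalize(features, norm_map, stats)

# The batch omits the camera key: previously a KeyError, now simply skipped.
out = normalize({"observation.state": torch.zeros(2, 6)})
```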
134 changes: 134 additions & 0 deletions lerobot/common/policies/pi0/configuration_pi0.py
@@ -0,0 +1,134 @@
from dataclasses import dataclass, field

from lerobot.common.optim.optimizers import AdamWConfig
from lerobot.common.optim.schedulers import (
CosineDecayWithWarmupSchedulerConfig,
)
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature


@PreTrainedConfig.register_subclass("pi0")
@dataclass
class PI0Config(PreTrainedConfig):
# Input / output structure.
n_obs_steps: int = 1
chunk_size: int = 50
n_action_steps: int = 50

normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MEAN_STD,
}
)

# Shorter state and action vectors will be padded
max_state_dim: int = 32
max_action_dim: int = 32

# Image preprocessing
resize_imgs_with_padding: tuple[int, int] = (224, 224)

# Add empty images. Used by pi0_aloha_sim which adds the empty
# left and right wrist cameras in addition to the top camera.
empty_cameras: int = 0

# Converts the joint and gripper values from the standard Aloha space to
# the space used by the pi internal runtime which was used to train the base model.
adapt_to_pi_aloha: bool = False

# Converts joint dimensions to deltas with respect to the current state before passing to the model.
# Gripper dimensions will remain in absolute values.
use_delta_joint_actions_aloha: bool = False

# Tokenizer
tokenizer_max_length: int = 48

# Projector
proj_width: int = 1024

# Decoding
num_steps: int = 10

# Attention utils
use_cache: bool = True
attention_implementation: str = "eager" # or fa2, flex

# Finetuning settings
freeze_vision_encoder: bool = True
train_expert_only: bool = False
train_state_proj: bool = True

# Training presets
optimizer_lr: float = 2.5e-5
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-10

scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6

# TODO: Add EMA

def __post_init__(self):
super().__post_init__()

"""Input validation (not exhaustive)."""
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
)
if self.n_obs_steps != 1:
raise ValueError(
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
)

if self.use_delta_joint_actions_aloha:
raise NotImplementedError(
"`use_delta_joint_actions_aloha` is used by pi0 for aloha real models. It is not ported yet in LeRobot."
)

def validate_features(self) -> None:
# TODO: implement value error
# if not self.image_features and not self.env_state_feature:
# raise ValueError("You must provide at least one image or the environment state among the inputs.")

for i in range(self.empty_cameras):
key = f"observation.images.empty_camera_{i}"
empty_camera = PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 480, 640),
)
self.input_features[key] = empty_camera

def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
)

def get_scheduler_preset(self):
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)

@property
def observation_delta_indices(self) -> None:
return None

@property
def action_delta_indices(self) -> list:
return list(range(self.chunk_size))

@property
def reward_delta_indices(self) -> None:
return None
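Taken together, the config wires its training presets straight into the optimizer and scheduler classes added above. A small sketch (values shown are the defaults defined in this file):

```python
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config

cfg = PI0Config(empty_cameras=2)  # e.g. pad in two wrist cameras, pi0_aloha_sim-style
opt_cfg = cfg.get_optimizer_preset()    # AdamWConfig(lr=2.5e-5, betas=(0.9, 0.95), ...)
sched_cfg = cfg.get_scheduler_preset()  # cosine decay: 1k warmup, 30k decay steps

# By default the whole 50-step action chunk is executed per model invocation.
assert cfg.n_action_steps == cfg.chunk_size == 50
```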