diff --git a/Makefile b/Makefile index 530eca2ec..c216e009d 100644 --- a/Makefile +++ b/Makefile @@ -26,7 +26,6 @@ test-end-to-end: ${MAKE} DEVICE=$(DEVICE) test-diffusion-ete-eval ${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-train ${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-eval - ${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-train-with-online test-act-ete-train: python lerobot/scripts/train.py \ @@ -128,28 +127,29 @@ test-tdmpc-ete-eval: --eval.batch_size=1 \ --device=$(DEVICE) -test-tdmpc-ete-train-with-online: - python lerobot/scripts/train.py \ - --policy.type=tdmpc \ - --env.type=pusht \ - --env.obs_type=environment_state_agent_pos \ - --env.episode_length=5 \ - --dataset.repo_id=lerobot/pusht_keypoints \ - --dataset.image_transforms.enable=true \ - --dataset.episodes="[0]" \ - --batch_size=2 \ - --offline.steps=2 \ - --online.steps=20 \ - --online.rollout_n_episodes=2 \ - --online.rollout_batch_size=2 \ - --online.steps_between_rollouts=10 \ - --online.buffer_capacity=1000 \ - --online.env_seed=10000 \ - --save_checkpoint=false \ - --save_freq=10 \ - --log_freq=1 \ - --eval.use_async_envs=true \ - --eval.n_episodes=1 \ - --eval.batch_size=1 \ - --device=$(DEVICE) \ - --output_dir=tests/outputs/tdmpc_online/ +# TODO(rcadene): fix online buffer to store "task" +# test-tdmpc-ete-train-with-online: +# python lerobot/scripts/train.py \ +# --policy.type=tdmpc \ +# --env.type=pusht \ +# --env.obs_type=environment_state_agent_pos \ +# --env.episode_length=5 \ +# --dataset.repo_id=lerobot/pusht_keypoints \ +# --dataset.image_transforms.enable=true \ +# --dataset.episodes="[0]" \ +# --batch_size=2 \ +# --offline.steps=2 \ +# --online.steps=20 \ +# --online.rollout_n_episodes=2 \ +# --online.rollout_batch_size=2 \ +# --online.steps_between_rollouts=10 \ +# --online.buffer_capacity=1000 \ +# --online.env_seed=10000 \ +# --save_checkpoint=false \ +# --save_freq=10 \ +# --log_freq=1 \ +# --eval.use_async_envs=true \ +# --eval.n_episodes=1 \ +# --eval.batch_size=1 \ +# 
--device=$(DEVICE) \ +# --output_dir=tests/outputs/tdmpc_online/ diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 5a5bd137c..9483bf0a9 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -672,6 +672,10 @@ def __getitem__(self, idx) -> dict: for cam in image_keys: item[cam] = self.image_transforms(item[cam]) + # Add task as a string + task_idx = item["task_index"].item() + item["task"] = self.meta.tasks[task_idx] + return item def __repr__(self): diff --git a/lerobot/common/optim/optimizers.py b/lerobot/common/optim/optimizers.py index 8a19615b6..737305ad0 100644 --- a/lerobot/common/optim/optimizers.py +++ b/lerobot/common/optim/optimizers.py @@ -8,8 +8,6 @@ @dataclass class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC): lr: float - betas: tuple[float, float] - eps: float weight_decay: float grad_clip_norm: float @@ -54,3 +52,19 @@ def build(self, params: dict) -> torch.optim.Optimizer: kwargs = asdict(self) kwargs.pop("grad_clip_norm") return torch.optim.AdamW(params, **kwargs) + + +@OptimizerConfig.register_subclass("sgd") +@dataclass +class SGDConfig(OptimizerConfig): + lr: float = 1e-3 + momentum: float = 0.0 + dampening: float = 0.0 + nesterov: bool = False + weight_decay: float = 0.0 + grad_clip_norm: float = 10.0 + + def build(self, params: dict) -> torch.optim.Optimizer: + kwargs = asdict(self) + kwargs.pop("grad_clip_norm") + return torch.optim.SGD(params, **kwargs) diff --git a/lerobot/common/optim/schedulers.py b/lerobot/common/optim/schedulers.py index 752ddeff2..80d83bdfc 100644 --- a/lerobot/common/optim/schedulers.py +++ b/lerobot/common/optim/schedulers.py @@ -54,3 +54,38 @@ def lr_lambda(current_step): return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress))) return LambdaLR(optimizer, lr_lambda, -1) + + +@LRSchedulerConfig.register_subclass("cosine_decay_with_warmup") +@dataclass +class 
CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig): + """Used by Physical Intelligence to train Pi0""" + + num_warmup_steps: int + num_decay_steps: int + peak_lr: float + decay_lr: float + + def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR: + del num_training_steps + + def lr_lambda(current_step): + def linear_warmup_schedule(current_step): + if current_step <= 0: + return 1 / (self.num_warmup_steps + 1) + frac = 1 - current_step / self.num_warmup_steps + return (1 / (self.num_warmup_steps + 1) - 1) * frac + 1 + + def cosine_decay_schedule(current_step): + step = min(current_step, self.num_decay_steps) + cosine_decay = 0.5 * (1 + math.cos(math.pi * step / self.num_decay_steps)) + alpha = self.decay_lr / self.peak_lr + decayed = (1 - alpha) * cosine_decay + alpha + return decayed + + if current_step < self.num_warmup_steps: + return linear_warmup_schedule(current_step) + + return cosine_decay_schedule(current_step) + + return LambdaLR(optimizer, lr_lambda, -1) diff --git a/lerobot/common/policies/__init__.py b/lerobot/common/policies/__init__.py index 58db9849f..2e4486efc 100644 --- a/lerobot/common/policies/__init__.py +++ b/lerobot/common/policies/__init__.py @@ -1,4 +1,5 @@ from .act.configuration_act import ACTConfig as ACTConfig from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig +from .pi0.configuration_pi0 import PI0Config as PI0Config from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index fb5b1159d..cd440f7a0 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -25,6 +25,7 @@ from lerobot.common.envs.utils import env_to_policy_features from lerobot.common.policies.act.configuration_act import ACTConfig from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig +from 
lerobot.common.policies.pi0.configuration_pi0 import PI0Config from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig @@ -50,6 +51,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy: from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy return VQBeTPolicy + elif name == "pi0": + from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy + + return PI0Policy else: raise NotImplementedError(f"Policy with name {name} is not implemented.") @@ -63,6 +68,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: return ACTConfig(**kwargs) elif policy_type == "vqbet": return VQBeTConfig(**kwargs) + elif policy_type == "pi0": + return PI0Config(**kwargs) else: raise ValueError(f"Policy type '{policy_type}' is not available.") @@ -141,4 +148,6 @@ def make_policy( policy.to(device) assert isinstance(policy, nn.Module) + # policy = torch.compile(policy, mode="reduce-overhead") + return policy diff --git a/lerobot/common/policies/normalize.py b/lerobot/common/policies/normalize.py index 0d188d815..952192736 100644 --- a/lerobot/common/policies/normalize.py +++ b/lerobot/common/policies/normalize.py @@ -140,6 +140,9 @@ def __init__( def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: batch = dict(batch) # shallow copy avoids mutating the input batch for key, ft in self.features.items(): + if key not in batch: + continue + norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY) if norm_mode is NormalizationMode.IDENTITY: continue @@ -210,6 +213,9 @@ def __init__( def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: batch = dict(batch) # shallow copy avoids mutating the input batch for key, ft in self.features.items(): + if key not in batch: + continue + norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY) if norm_mode is 
NormalizationMode.IDENTITY: continue diff --git a/lerobot/common/policies/pi0/configuration_pi0.py b/lerobot/common/policies/pi0/configuration_pi0.py new file mode 100644 index 000000000..8d2eedf69 --- /dev/null +++ b/lerobot/common/policies/pi0/configuration_pi0.py @@ -0,0 +1,134 @@ +from dataclasses import dataclass, field + +from lerobot.common.optim.optimizers import AdamWConfig +from lerobot.common.optim.schedulers import ( + CosineDecayWithWarmupSchedulerConfig, +) +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature + + +@PreTrainedConfig.register_subclass("pi0") +@dataclass +class PI0Config(PreTrainedConfig): + # Input / output structure. + n_obs_steps: int = 1 + chunk_size: int = 50 + n_action_steps: int = 50 + + normalization_mapping: dict[str, NormalizationMode] = field( + default_factory=lambda: { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.MEAN_STD, + "ACTION": NormalizationMode.MEAN_STD, + } + ) + + # Shorter state and action vectors will be padded + max_state_dim: int = 32 + max_action_dim: int = 32 + + # Image preprocessing + resize_imgs_with_padding: tuple[int, int] = (224, 224) + + # Add empty images. Used by pi0_aloha_sim which adds the empty + # left and right wrist cameras in addition to the top camera. + empty_cameras: int = 0 + + # Converts the joint and gripper values from the standard Aloha space to + # the space used by the pi internal runtime which was used to train the base model. + adapt_to_pi_aloha: bool = False + + # Converts joint dimensions to deltas with respect to the current state before passing to the model. + # Gripper dimensions will remain in absolute values. 
+ use_delta_joint_actions_aloha: bool = False + + # Tokenizer + tokenizer_max_length: int = 48 + + # Projector + proj_width: int = 1024 + + # Decoding + num_steps: int = 10 + + # Attention utils + use_cache: bool = True + attention_implementation: str = "eager" # or fa2, flex + + # Finetuning settings + freeze_vision_encoder: bool = True + train_expert_only: bool = False + train_state_proj: bool = True + + # Training presets + optimizer_lr: float = 2.5e-5 + optimizer_betas: tuple[float, float] = (0.9, 0.95) + optimizer_eps: float = 1e-8 + optimizer_weight_decay: float = 1e-10 + + scheduler_warmup_steps: int = 1_000 + scheduler_decay_steps: int = 30_000 + scheduler_decay_lr: float = 2.5e-6 + + # TODO: Add EMA + + def __post_init__(self): + super().__post_init__() + + """Input validation (not exhaustive).""" + if self.n_action_steps > self.chunk_size: + raise ValueError( + f"The chunk size is the upper bound for the number of action steps per model invocation. Got " + f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`." + ) + if self.n_obs_steps != 1: + raise ValueError( + f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`" + ) + + if self.use_delta_joint_actions_aloha: + raise NotImplementedError( + "`use_delta_joint_actions_aloha` is used by pi0 for aloha real models. It is not ported yet in LeRobot." 
+ ) + + def validate_features(self) -> None: + # TODO: implement value error + # if not self.image_features and not self.env_state_feature: + # raise ValueError("You must provide at least one image or the environment state among the inputs.") + + for i in range(self.empty_cameras): + key = f"observation.images.empty_camera_{i}" + empty_camera = PolicyFeature( + type=FeatureType.VISUAL, + shape=(3, 480, 640), + ) + self.input_features[key] = empty_camera + + def get_optimizer_preset(self) -> AdamWConfig: + return AdamWConfig( + lr=self.optimizer_lr, + betas=self.optimizer_betas, + eps=self.optimizer_eps, + weight_decay=self.optimizer_weight_decay, + ) + + def get_scheduler_preset(self): + return CosineDecayWithWarmupSchedulerConfig( + peak_lr=self.optimizer_lr, + decay_lr=self.scheduler_decay_lr, + num_warmup_steps=self.scheduler_warmup_steps, + num_decay_steps=self.scheduler_decay_steps, + ) + + @property + def observation_delta_indices(self) -> None: + return None + + @property + def action_delta_indices(self) -> list: + return list(range(self.chunk_size)) + + @property + def reward_delta_indices(self) -> None: + return None diff --git a/lerobot/common/policies/pi0/conversion_scripts/benchmark.py b/lerobot/common/policies/pi0/conversion_scripts/benchmark.py new file mode 100644 index 000000000..31bd1b66a --- /dev/null +++ b/lerobot/common/policies/pi0/conversion_scripts/benchmark.py @@ -0,0 +1,68 @@ +import torch + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.policies.factory import make_policy +from lerobot.configs.policies import PreTrainedConfig + +torch.backends.cudnn.benchmark = True + + +def main(): + device = "cuda" + dataset_repo_id = "danaaubakirova/koch_test" + # model_name = "pi0_base" + # ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch" + ckpt_torch_dir = "lerobot/pi0" + + dataset = LeRobotDataset(dataset_repo_id, episodes=[0]) + + dataloader = 
torch.utils.data.DataLoader( + dataset, + num_workers=0, + batch_size=1, + ) + + batch = next(iter(dataloader)) + + # To device + for k in batch: + if isinstance(batch[k], torch.Tensor): + batch[k] = batch[k].to(device=device, dtype=torch.float32) + + cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir) + cfg.pretrained_path = ckpt_torch_dir + policy = make_policy(cfg, device, ds_meta=dataset.meta) + + # policy = torch.compile(policy, mode="reduce-overhead") + + warmup_iters = 10 + benchmark_iters = 30 + + # Warmup + for _ in range(warmup_iters): + torch.cuda.synchronize() + policy.select_action(batch) + policy.reset() + torch.cuda.synchronize() + + # Benchmark + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + for _ in range(benchmark_iters): + policy.select_action(batch) + policy.reset() + end_event.record() + + # Synchronize and measure time + torch.cuda.synchronize() + elapsed_time_ms = start_event.elapsed_time(end_event) + + avg_time_per_iter = elapsed_time_ms / benchmark_iters + print(f"Average execution time per iteration: {avg_time_per_iter:.3f} ms") + + +if __name__ == "__main__": + with torch.inference_mode(): + main() diff --git a/lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py b/lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py new file mode 100644 index 000000000..8b2e1c663 --- /dev/null +++ b/lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py @@ -0,0 +1,117 @@ +import json +import pickle +from pathlib import Path + +import torch + +from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata +from lerobot.common.policies.factory import make_policy +from lerobot.configs.policies import PreTrainedConfig + + +def display(tensor: torch.Tensor): + if tensor.dtype == torch.bool: + tensor = tensor.float() + print(f"Shape: {tensor.shape}") + print(f"Mean: {tensor.mean().item()}") + print(f"Std: 
{tensor.std().item()}") + print(f"Min: {tensor.min().item()}") + print(f"Max: {tensor.max().item()}") + + +def main(): + num_motors = 14 + device = "cuda" + # model_name = "pi0_aloha_towel" + model_name = "pi0_aloha_sim" + + if model_name == "pi0_aloha_towel": + dataset_repo_id = "lerobot/aloha_static_towel" + else: + dataset_repo_id = "lerobot/aloha_sim_transfer_cube_human" + + ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch" + ckpt_jax_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}" + save_dir = Path(f"../openpi/data/{model_name}/save") + + with open(save_dir / "example.pkl", "rb") as f: + example = pickle.load(f) + with open(save_dir / "outputs.pkl", "rb") as f: + outputs = pickle.load(f) + with open(save_dir / "noise.pkl", "rb") as f: + noise = pickle.load(f) + + with open(ckpt_jax_dir / "assets/norm_stats.json") as f: + norm_stats = json.load(f) + + # Override stats + dataset_meta = LeRobotDatasetMetadata(dataset_repo_id) + dataset_meta.stats["observation.state"]["mean"] = torch.tensor( + norm_stats["norm_stats"]["state"]["mean"][:num_motors], dtype=torch.float32 + ) + dataset_meta.stats["observation.state"]["std"] = torch.tensor( + norm_stats["norm_stats"]["state"]["std"][:num_motors], dtype=torch.float32 + ) + + # Create LeRobot batch from Jax + batch = {} + for cam_key, uint_chw_array in example["images"].items(): + batch[f"observation.images.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0 + batch["observation.state"] = torch.from_numpy(example["state"]) + batch["action"] = torch.from_numpy(outputs["actions"]) + batch["task"] = example["prompt"] + + if model_name == "pi0_aloha_towel": + del batch["observation.images.cam_low"] + elif model_name == "pi0_aloha_sim": + batch["observation.images.top"] = batch["observation.images.cam_high"] + del batch["observation.images.cam_high"] + + # Batchify + for key in batch: + if isinstance(batch[key], torch.Tensor): + batch[key] = 
batch[key].unsqueeze(0) + elif isinstance(batch[key], str): + batch[key] = [batch[key]] + else: + raise ValueError(f"{key}, {batch[key]}") + + # To device + for k in batch: + if isinstance(batch[k], torch.Tensor): + batch[k] = batch[k].to(device=device, dtype=torch.float32) + + noise = torch.from_numpy(noise).to(device=device, dtype=torch.float32) + + from lerobot.common import policies # noqa + + cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir) + cfg.pretrained_path = ckpt_torch_dir + policy = make_policy(cfg, device, dataset_meta) + + # loss_dict = policy.forward(batch, noise=noise, time=time_beta) + # loss_dict["loss"].backward() + # print("losses") + # display(loss_dict["losses_after_forward"]) + # print("pi_losses") + # display(pi_losses) + + actions = [] + for _ in range(50): + action = policy.select_action(batch, noise=noise) + actions.append(action) + + actions = torch.stack(actions, dim=1) + pi_actions = batch["action"] + print("actions") + display(actions) + print() + print("pi_actions") + display(pi_actions) + print("atol=3e-2", torch.allclose(actions, pi_actions, atol=3e-2)) + print("atol=2e-2", torch.allclose(actions, pi_actions, atol=2e-2)) + print("atol=1e-2", torch.allclose(actions, pi_actions, atol=1e-2)) + + +if __name__ == "__main__": + main() diff --git a/lerobot/common/policies/pi0/conversion_scripts/conversion_utils.py b/lerobot/common/policies/pi0/conversion_scripts/conversion_utils.py new file mode 100644 index 000000000..8e35d0d47 --- /dev/null +++ b/lerobot/common/policies/pi0/conversion_scripts/conversion_utils.py @@ -0,0 +1,70 @@ +from transformers import GemmaConfig, PaliGemmaConfig + + +def get_paligemma_config(precision: str): + config = { + "image_token_index": None, + "pad_token_id": 0, + "bos_token_id": 2, + "eos_token_id": 1, + } + + # image_sizes = {"2b-test": 224, "3b-224px": 224, "3b-448px": 448, "3b-896px": 896} + + image_size = 224 # image_sizes[variant] + patch_size = 14 + num_image_tokens = (image_size**2) // 
(patch_size**2) + + config["image_token_index"] = 257152 + text_config = { + "vocab_size": 257152, + "num_hidden_layers": 18, + "num_key_value_heads": 1, + "head_dim": 256, + "torch_dtype": precision, + "hidden_size": 2048, + "hidden_activation": "gelu_pytorch_tanh", + "num_attention_heads": 8, + "intermediate_size": 16384, + "is_encoder_decoder": False, + } + vision_config = { + "torch_dtype": precision, + "image_size": image_size, + "patch_size": patch_size, + "num_image_tokens": num_image_tokens, + "hidden_size": 1152, + "intermediate_size": 4304, + "num_hidden_layers": 27, + "num_attention_heads": 16, + "projector_hidden_act": "gelu_fast", + "vision_use_head": False, + } + final_config = PaliGemmaConfig(text_config=text_config, vision_config=vision_config, **config) + return final_config + + +def get_gemma_config(precision: str): + config = { + "image_token_index": None, + "pad_token_id": 0, + "bos_token_id": 2, + "eos_token_id": 1, + } + + config["image_token_index"] = 257152 + text_config = { + "vocab_size": 257152, + "num_hidden_layers": 18, + "num_key_value_heads": 1, + "head_dim": 256, + "torch_dtype": precision, + "hidden_size": 1024, + "hidden_activation": "gelu_pytorch_tanh", + "num_attention_heads": 8, + "intermediate_size": 4096, + "is_encoder_decoder": False, + } + final_config = GemmaConfig() + final_config.update(text_config) + return final_config diff --git a/lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py b/lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py new file mode 100644 index 000000000..f85437a51 --- /dev/null +++ b/lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py @@ -0,0 +1,423 @@ +""" +Convert pi0 parameters from Jax to Pytorch + +Follow [README of openpi](https://github.com/Physical-Intelligence/openpi) to create a new environment +and install the required libraries. 
+ +```bash +cd ~/code/openpi +source .venv/bin/activate +``` + +Example downloading parameters: +```bash +python +>>> import openpi.shared.download as download +>>> path='s3://openpi-assets/checkpoints/pi0_base/params' +>>> download.maybe_download(path) +``` + +Converting pi0_base: +```python +python lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py \ + --checkpoint_dir /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_base/params \ + --output_path /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_base_pytorch +``` + +```python +python lerobot/common/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py \ + --checkpoint_dir /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params \ + --output_path /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch +``` +""" + +import argparse +import pathlib + +import jax +import numpy as np +import orbax.checkpoint as ocp +import torch +from jax.sharding import SingleDeviceSharding + +from lerobot.common.policies.pi0.configuration_pi0 import PI0Config +from lerobot.common.policies.pi0.conversion_scripts.conversion_utils import ( + get_gemma_config, + get_paligemma_config, +) +from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy + +PRECISIONS = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16} + + +def slice_paligemma_state_dict(state_dict, config): + suffix = "/value" if "img/embedding/kernel/value" in state_dict else "" + + # fmt: off + # patch embeddings + state_dict["paligemma.vision_tower.vision_model.embeddings.patch_embedding.weight"] = state_dict.pop(f"img/embedding/kernel{suffix}").transpose( + 3, 2, 0, 1 + ) + state_dict["paligemma.vision_tower.vision_model.embeddings.patch_embedding.bias"] = state_dict.pop(f"img/embedding/bias{suffix}") + # positional embeddings + state_dict["paligemma.vision_tower.vision_model.embeddings.position_embedding.weight"] = 
state_dict.pop(f"img/pos_embedding{suffix}").reshape( + -1, config.vision_config.hidden_size + ) + + # extract vision layers to be sliced at index 0. There are 27 layers in the base model. + encoderblock_layernorm0_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/scale{suffix}") + encoderblock_layernorm0_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/bias{suffix}") + encoderblock_layernorm1_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/scale{suffix}") + encoderblock_layernorm1_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/bias{suffix}") + + encoderblock_mlp_dense0_kernel= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel{suffix}") + encoderblock_mlp_dense0_bias= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias{suffix}") + encoderblock_mlp_dense1_kernel= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel{suffix}") + encoderblock_mlp_dense1_bias= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias{suffix}") + + encoderblock_attention_0_key_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel{suffix}") + encoderblock_attention_0_key_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias{suffix}") + encoderblock_attention_0_value_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel{suffix}") + encoderblock_attention_0_value_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias{suffix}") + encoderblock_attention_0_query_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel{suffix}") + encoderblock_attention_0_query_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias{suffix}") + encoderblock_attention_0_out_kernel = 
state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel{suffix}") + encoderblock_attention_0_out_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias{suffix}") + + for i in range(config.vision_config.num_hidden_layers): + state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"] = encoderblock_layernorm0_scale[i].transpose() + state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"] = encoderblock_layernorm0_bias[i] + state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"] = encoderblock_layernorm1_scale[i].transpose() + state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"] = encoderblock_layernorm1_bias[i] + + state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"] = encoderblock_mlp_dense0_kernel[i].transpose() + state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"] = encoderblock_mlp_dense0_bias[i] + state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"] = encoderblock_mlp_dense1_kernel[i].transpose() + state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"] = encoderblock_mlp_dense1_bias[i] + state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() + state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) + state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() + 
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) + state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() + state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) + state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() + state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) + + state_dict["paligemma.vision_tower.vision_model.post_layernorm.weight"] = state_dict.pop(f"img/Transformer/encoder_norm/scale{suffix}").transpose() + state_dict["paligemma.vision_tower.vision_model.post_layernorm.bias"] = state_dict.pop(f"img/Transformer/encoder_norm/bias{suffix}") + + # multimodal projector + + state_dict['paligemma.multi_modal_projector.linear.weight'] = state_dict.pop(f"img/head/kernel{suffix}").transpose() + state_dict['paligemma.multi_modal_projector.linear.bias'] = state_dict.pop(f"img/head/bias{suffix}") + + # text decoder (gemma) + embedding_vector = state_dict.pop(f"llm/embedder/input_embedding{suffix}") + state_dict["paligemma.language_model.model.embed_tokens.weight"] = embedding_vector + + # pop the einsum attention + mlp representations. There are 18 layers in gemma-2b. 
+ + llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum/w{suffix}") + llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum/w{suffix}") + llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum/w{suffix}") + + llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp/gating_einsum{suffix}") + llm_mlp_linear = state_dict.pop(f"llm/layers/mlp/linear{suffix}") + # TODO verify correctness of layer norm loading + + llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm/scale{suffix}") + llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm/scale{suffix}") + + for i in range(config.text_config.num_hidden_layers): + # llm_attention_q_einsum[i].shape = (8, 2048, 256) + q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size) + + state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped + + # llm_attention_kv_einsum[i, 0, 0].shape = (2048, 256) + k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose() + state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped + # llm_attention_kv_einsum[i, 1, 0].shape = (2048, 256) + v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose() + state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped + + # output projection. 
+ + # llm_attention_attn_vec_einsum[i].shape = (8, 256, 2048) + o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].transpose(2, 0, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size) + + state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped + # mlp layers + gate_proj_weight = llm_mlp_gating_einsum[i, 0] + state_dict[f"paligemma.language_model.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose() + up_proj_weight = llm_mlp_gating_einsum[i, 1] + state_dict[f"paligemma.language_model.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose() + state_dict[f"paligemma.language_model.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose() + state_dict[f"paligemma.language_model.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i] + state_dict[f"paligemma.language_model.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i] + + state_dict["paligemma.language_model.model.norm.weight"] = state_dict.pop(f"llm/final_norm/scale{suffix}") + state_dict["paligemma.language_model.lm_head.weight"] = embedding_vector # weights are tied. 
+ + # fmt: on + expert_dict = {} + final_state_dict = {} + for key, value in state_dict.items(): + if key not in [ + f"llm/final_norm_1/scale{suffix}", + f"llm/layers/attn/attn_vec_einsum_1/w{suffix}", + f"llm/layers/attn/kv_einsum_1/w{suffix}", + f"llm/layers/attn/q_einsum_1/w{suffix}", + f"llm/layers/mlp_1/gating_einsum{suffix}", + f"llm/layers/mlp_1/linear{suffix}", + f"llm/layers/pre_attention_norm_1/scale{suffix}", + f"llm/layers/pre_ffw_norm_1/scale{suffix}", + ]: + final_state_dict[key] = torch.from_numpy(value) + else: + expert_dict[key] = value + + return final_state_dict, expert_dict + + +def slice_gemma_state_dict(state_dict, config, num_expert=1): + # fmt: off + # text decoder (gemma) + # no embedding vector, the expert just has the decoder layers + + embedding_vector = torch.zeros([config.vocab_size, config.hidden_size]) + state_dict["gemma_expert.model.embed_tokens.weight"] = embedding_vector + + # pop the einsum attention + mlp representations. There are 18 layers in gemma-2b. 
+ + suffix = "/value" if f"llm/layers/attn/attn_vec_einsum_{num_expert}/w/value" in state_dict else "" + + llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum_{num_expert}/w{suffix}") + llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum_{num_expert}/w{suffix}") + llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum_{num_expert}/w{suffix}") + + llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp_{num_expert}/gating_einsum{suffix}") + llm_mlp_linear = state_dict.pop(f"llm/layers/mlp_{num_expert}/linear{suffix}") + # TODO verify correctness of layer norm loading + + llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/scale{suffix}") + llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/scale{suffix}") + + for i in range(config.num_hidden_layers): + q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.num_attention_heads * config.head_dim, config.hidden_size) + + state_dict[f"gemma_expert.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped + + k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose() + state_dict[f"gemma_expert.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped + v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose() + state_dict[f"gemma_expert.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped + + # output projection. 
+ + # llm_attention_attn_vec_einsum[i].shape = (8, 256, 1024) + o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].reshape(config.num_attention_heads * config.head_dim, config.hidden_size).transpose(1,0)# .transpose(2, 0, 1).reshape(config.num_attention_heads * config.head_dim, config.hidden_size).transpose(1, 0) + + state_dict[f"gemma_expert.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped + # mlp layers + gate_proj_weight = llm_mlp_gating_einsum[i, 0] + state_dict[f"gemma_expert.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose() + up_proj_weight = llm_mlp_gating_einsum[i, 1] + state_dict[f"gemma_expert.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose() + state_dict[f"gemma_expert.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose() + state_dict[f"gemma_expert.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i] + state_dict[f"gemma_expert.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i] + + state_dict["gemma_expert.model.norm.weight"] = state_dict.pop(f"llm/final_norm_{num_expert}/scale{suffix}") + state_dict["gemma_expert.lm_head.weight"] = embedding_vector # weights are tied. 
(and zeros here) + + # fmt: on + final_state_dict = {} + for key, value in state_dict.items(): + if not isinstance(value, torch.Tensor): + final_state_dict[key] = torch.from_numpy(value) + else: + final_state_dict[key] = value + return final_state_dict + + +def flatten_for_memory(tree, parent_key=""): + out = {} + for k, v in tree.items(): + new_key = f"{parent_key}/{k}" if parent_key else k + if isinstance(v, dict): + out.update(flatten_for_memory(v, new_key)) + else: + out[new_key] = np.array(v) # Ensure conversion to np.array for consistency + return out + + +def flatten_for_npz(tree, parent_key=""): + out = {} + for k, v in tree.items(): + new_key = f"{parent_key}/{k}" if parent_key else k + if isinstance(v, dict): + out.update(flatten_for_npz(v, new_key)) + else: + # bf16/f32 here? + out[new_key] = np.array(v) + return out + + +def slice_initial_orbax_checkpoint(checkpoint_dir: str): + params_path = pathlib.Path(checkpoint_dir).resolve() + checkpointer = ocp.PyTreeCheckpointer() + + metadata = checkpointer.metadata(params_path) + print("Metadata keys:", list(metadata.keys())) + + params_name = "params" + + item = {params_name: metadata[params_name]} + device = jax.local_devices()[0] # Use the first local device + sharding = SingleDeviceSharding(device) + restored = checkpointer.restore( + params_path, + ocp.args.PyTreeRestore( + item=item, + restore_args=jax.tree_util.tree_map( + lambda _: ocp.ArrayRestoreArgs( + restore_type=jax.Array, # or np.ndarray, but bf16 is annoying about it + sharding=sharding, + ), + item, + ), + transforms={}, + ), + ) + params = restored[params_name] + + # get params for PaliGemma + pali_params = params["PaliGemma"] + del params["PaliGemma"] + pali_params_flat = flatten_for_npz(pali_params) + return {"paligemma_params": pali_params_flat, "projection_params": params} + + +def update_keys_with_prefix(d: dict, prefix: str) -> dict: + """Update dictionary keys by adding a prefix.""" + return {f"{prefix}{key}": value for key, value in 
d.items()} + + +def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str): + # Break down orbax ckpts - they are in OCDBT + initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir) + # process projection params + keys = [ + "state_proj", + "action_in_proj", + "action_out_proj", + "action_time_mlp_in", + "action_time_mlp_out", + ] + + projection_params = {} + for key in keys: + kernel_params = initial_params["projection_params"][key]["kernel"] + bias_params = initial_params["projection_params"][key]["bias"] + if isinstance(kernel_params, dict): + weight = kernel_params["value"] + bias = bias_params["value"] + else: + weight = kernel_params + bias = bias_params + projection_params[f"{key}.weight"] = torch.from_numpy(np.array(weight)).T + projection_params[f"{key}.bias"] = torch.from_numpy(np.array(bias)) + + # Process PaliGemma weights + paligemma_config = get_paligemma_config(precision) + paligemma_params, gemma_raw_dictionary = slice_paligemma_state_dict( + initial_params["paligemma_params"], paligemma_config + ) + + # Process Gemma weights (at this stage they are unused) + gemma_config = get_gemma_config(precision) + gemma_params = slice_gemma_state_dict(gemma_raw_dictionary, config=gemma_config) + + # Instantiate model from configs + + if "pi0_aloha_sim" in checkpoint_dir: + pi0_config = PI0Config( + empty_cameras=2, + adapt_to_pi_aloha=True, + use_delta_joint_actions_aloha=False, + ) + elif "pi0_aloha_towel" in checkpoint_dir: + pi0_config = PI0Config( + adapt_to_pi_aloha=True, + use_delta_joint_actions_aloha=True, + ) + elif "pi0_base" in checkpoint_dir: + pi0_config = PI0Config( + empty_cameras=0, + adapt_to_pi_aloha=False, + use_delta_joint_actions_aloha=False, + ) + else: + raise ValueError() + + # gemma_config=gemma_config, paligemma_config=paligemma_config) + pi0_model = PI0Policy(pi0_config) + + paligemma_params = update_keys_with_prefix(paligemma_params, "model.paligemma_with_expert.") + 
gemma_params = update_keys_with_prefix(gemma_params, "model.paligemma_with_expert.") + projection_params = update_keys_with_prefix(projection_params, "model.") + + # load state dict + torch_dtype = PRECISIONS[precision] + pi0_model.load_state_dict({**paligemma_params, **gemma_params, **projection_params}) + pi0_model = pi0_model.to(torch_dtype) + # pi0_tokenizer = AutoTokenizer.from_pretrained(tokenizer_id) + + pi0_model.save_pretrained(output_path, safe_serialization=True) + # pi0_tokenizer.save_pretrained(output_path, dtype=torch_dtype) + + # assert that model loads properly + del pi0_model + PI0Policy.from_pretrained(output_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--checkpoint_dir", + default="/raid/pablo/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params", + type=str, + help="Path to the ocdbt checkpoint", + ) + + parser.add_argument( + "--precision", + choices=["float32", "bfloat16", "float16"], + default="float32", + type=str, + help="Precision identifier for model conversion - should match the base checkpoint precision.", + ) + # tokenizer is identical to paligemma, it appears + + parser.add_argument( + "--tokenizer_hub_id", + default="google/paligemma-3b-pt-224", + type=str, + help="Hub path to the tokenizer to save", + ) + + parser.add_argument( + "--output_path", + required=True, + type=str, + help="Path to save converted weights to", + ) + + args = parser.parse_args() + convert_pi0_checkpoint( + checkpoint_dir=args.checkpoint_dir, + precision=args.precision, + tokenizer_id=args.tokenizer_hub_id, + output_path=args.output_path, + ) diff --git a/lerobot/common/policies/pi0/flex_attention.py b/lerobot/common/policies/pi0/flex_attention.py new file mode 100644 index 000000000..38a5b5976 --- /dev/null +++ b/lerobot/common/policies/pi0/flex_attention.py @@ -0,0 +1,127 @@ +import torch +import torch.nn.functional as F # noqa: N812 +from packaging.version import Version + +if 
Version(torch.__version__) > Version("2.5.0"): + # Flex attention is only available from torch 2.5 onwards + from torch.nn.attention.flex_attention import ( + _mask_mod_signature, + _round_up_to_multiple, + create_block_mask, + create_mask, + flex_attention, + ) + + +# @torch.compile(dynamic=False) +def flex_attention_forward( + attention_mask: torch.Tensor, + batch_size: int, + head_dim: int, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + scaling=None, +): + """ + This is defined out of classes to make compile happy. + """ + + original_dtype = query_states.dtype + num_att_heads = 8 + num_key_value_heads = 1 + num_key_value_groups = num_att_heads // num_key_value_heads + + key_states = key_states[:, :, :, None, :] + key_states = key_states.expand( + batch_size, key_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim + ) + key_states = key_states.reshape( + batch_size, key_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim + ) + + value_states = value_states[:, :, :, None, :] + value_states = value_states.expand( + batch_size, value_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim + ) + value_states = value_states.reshape( + batch_size, value_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim + ) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + query_states = query_states.to(torch.float32) + key_states = key_states.to(torch.float32) + value_states = value_states.to(torch.float32) + + causal_mask = attention_mask + if causal_mask is not None: + causal_mask = causal_mask[:, None, :, : key_states.shape[2]] + + if causal_mask.shape[1] == 1 and query_states.shape[1] > 1: + causal_mask = causal_mask.expand(-1, query_states.shape[1], -1, -1) + + def precomputed_mask_factory(precomputed_mask: torch.Tensor) -> _mask_mod_signature: + def mask_mod(b, h, q_idx, kv_idx): + 
# Danger zone: if b,h,q_idx,kv_idx exceed the shape, device-side assert occurs. + return precomputed_mask[b][h][q_idx][kv_idx] + + return mask_mod + + b_mask, h_mask, q_len, kv_len = causal_mask.shape # The shape of your mask + + block_size = 128 + q_len_rounded = _round_up_to_multiple(q_len, block_size) + kv_len_rounded = _round_up_to_multiple(kv_len, block_size) + + # *CRITICAL* we do need to expand here, else we get a CUDA index error + + pad_q = q_len_rounded - q_len + pad_k = kv_len_rounded - kv_len + + padded_causal_mask = F.pad(causal_mask, (0, pad_k, 0, pad_q), value=0.0) + mask_mod_fn_orig = precomputed_mask_factory(padded_causal_mask) + + mask_4d = create_mask( + mod_fn=mask_mod_fn_orig, + B=b_mask, + H=h_mask, + Q_LEN=q_len_rounded, + KV_LEN=kv_len_rounded, + device=causal_mask.device, + _compile=False, + ) + + mask_mod_fn_padded = precomputed_mask_factory(mask_4d) + block_mask = create_block_mask( + mask_mod=mask_mod_fn_padded, + B=b_mask, + H=h_mask, + Q_LEN=q_len_rounded, + KV_LEN=kv_len_rounded, + BLOCK_SIZE=block_size, + device=causal_mask.device, + _compile=False, + ) + + # mask is applied inside the kernel, ideally more efficiently than score_mod. 
+ attn_output, attention_weights = flex_attention( + query_states, + key_states, + value_states, + block_mask=block_mask, + enable_gqa=True, # because we shaped query/key states for GQA + scale=head_dim**-0.5 if scaling is None else scaling, + return_lse=True, + ) + + attn_output = attn_output.to(dtype=original_dtype) + attn_output = attn_output.transpose(1, 2).contiguous() # [B, Q_LEN, H, head_dim] + attn_output = attn_output.reshape( + batch_size, + -1, + attn_output.shape[2] * attn_output.shape[3], # merges [H, head_dim] + ) + return attn_output diff --git a/lerobot/common/policies/pi0/modeling_pi0.py b/lerobot/common/policies/pi0/modeling_pi0.py new file mode 100644 index 000000000..90d1a14c9 --- /dev/null +++ b/lerobot/common/policies/pi0/modeling_pi0.py @@ -0,0 +1,732 @@ +#!/usr/bin/env python + +# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +π0: A Vision-Language-Action Flow Model for General Robot Control + +[Paper](https://www.physicalintelligence.company/download/pi0.pdf) +[Jax code](https://github.com/Physical-Intelligence/openpi) + +Designed by Physical Intelligence. Ported from Jax by Hugging Face. 
+ +Install pi0 extra dependencies: +```bash +pip install -e ".[pi0]" +``` + +Example of finetuning the pi0 pretrained model (`pi0_base` in `openpi`): +```bash +python lerobot/scripts/train.py \ +--policy.path=lerobot/pi0 \ +--dataset.repo_id=danaaubakirova/koch_test +``` + +Example of finetuning the pi0 neural network with PaliGemma and expert Gemma +pretrained with VLM default parameters before pi0 finetuning: +```bash +python lerobot/scripts/train.py \ +--policy.type=pi0 \ +--dataset.repo_id=danaaubakirova/koch_test +``` + +Example of using the pi0 pretrained model outside LeRobot training framework: +```python +policy = PI0Policy.from_pretrained("lerobot/pi0") +``` + +""" + +import math +from collections import deque + +import torch +import torch.nn.functional as F # noqa: N812 +from torch import Tensor, nn +from transformers import AutoTokenizer + +from lerobot.common.constants import ACTION, OBS_ROBOT +from lerobot.common.policies.normalize import Normalize, Unnormalize +from lerobot.common.policies.pi0.configuration_pi0 import PI0Config +from lerobot.common.policies.pi0.paligemma_with_expert import ( + PaliGemmaWithExpertConfig, + PaliGemmaWithExpertModel, +) +from lerobot.common.policies.pretrained import PreTrainedPolicy +from lerobot.common.utils.utils import get_safe_dtype + + +def create_sinusoidal_pos_embedding( + time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu" +) -> Tensor: + """Computes sine-cosine positional embedding vectors for scalar positions.""" + if dimension % 2 != 0: + raise ValueError(f"dimension ({dimension}) must be divisible by 2") + + if time.ndim != 1: + raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.") + + dtype = get_safe_dtype(torch.float64, device.type) + fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device) + period = min_period * (max_period / min_period) ** fraction + + # Compute the outer product + scaling_factor = 1.0 / period * 2 * 
math.pi + sin_input = scaling_factor[None, :] * time[:, None] + pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1) + return pos_emb + + +def sample_beta(alpha, beta, bsize, device): + gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha) + gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta) + return gamma1 / (gamma1 + gamma2) + + +def make_att_2d_masks(pad_masks, att_masks): + """Copied from big_vision. + + Tokens can attend to valid inputs tokens which have a cumulative mask_ar + smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to + setup several types of attention, for example: + + [[1 1 1 1 1 1]]: pure causal attention. + + [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between + themselves and the last 3 tokens have a causal attention. The first + entry could also be a 1 without changing behaviour. + + [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a + block can attend all previous blocks and all tokens on the same block. + + Args: + input_mask: bool[B, N] true if its part of the input, false if padding. + mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on + it and 0 where it shares the same attention mask as the previous token. 
+ """ + if att_masks.ndim != 2: + raise ValueError(att_masks.ndim) + if pad_masks.ndim != 2: + raise ValueError(pad_masks.ndim) + + cumsum = torch.cumsum(att_masks, dim=1) + att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None] + pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None] + att_2d_masks = att_2d_masks & pad_2d_masks + return att_2d_masks + + +def resize_with_pad(img, width, height, pad_value=-1): + # assume no-op when width height fits already + if img.ndim != 4: + raise ValueError(f"(b,c,h,w) expected, but {img.shape}") + + cur_height, cur_width = img.shape[2:] + + ratio = max(cur_width / width, cur_height / height) + resized_height = int(cur_height / ratio) + resized_width = int(cur_width / ratio) + resized_img = F.interpolate( + img, size=(resized_height, resized_width), mode="bilinear", align_corners=False + ) + + pad_height = max(0, int(height - resized_height)) + pad_width = max(0, int(width - resized_width)) + + # pad on left and top of image + padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value) + return padded_img + + +def pad_vector(vector, new_dim): + """Can be (batch_size x sequence_length x features_dimension) + or (batch_size x features_dimension) + """ + if vector.shape[-1] == new_dim: + return vector + shape = list(vector.shape) + current_dim = shape[-1] + shape[-1] = new_dim + new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device) + new_vector[..., :current_dim] = vector + return new_vector + + +def normalize(x, min_val, max_val): + return (x - min_val) / (max_val - min_val) + + +def unnormalize(x, min_val, max_val): + return x * (max_val - min_val) + min_val + + +def safe_arcsin(value): + # This ensures that the input stays within + # [−1,1] to avoid invalid values for arcsin + return torch.arcsin(torch.clamp(value, -1.0, 1.0)) + + +def aloha_gripper_to_angular(value): + # Aloha transforms the gripper positions into a linear space. 
The following code + # reverses this transformation to be consistent with pi0 which is pretrained in + # angular space. + # + # These values are coming from the Aloha code: + # PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED + value = unnormalize(value, min_val=0.01844, max_val=0.05800) + + # This is the inverse of the angular to linear transformation inside the Interbotix code. + def linear_to_radian(linear_position, arm_length, horn_radius): + value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position) + return safe_arcsin(value) + + # The constants are taken from the Interbotix code. + value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022) + + # Normalize to [0, 1]. + # The values 0.4 and 1.5 were measured on an actual Trossen robot. + return normalize(value, min_val=0.4, max_val=1.5) + + +def aloha_gripper_from_angular(value): + # Convert from the gripper position used by pi0 to the gripper position that is used by Aloha. + # Note that the units are still angular but the range is different. + + # The values 0.4 and 1.5 were measured on an actual Trossen robot. + value = unnormalize(value, min_val=0.4, max_val=1.5) + + # These values are coming from the Aloha code: + # PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE + return normalize(value, min_val=-0.6213, max_val=1.4910) + + +def aloha_gripper_from_angular_inv(value): + # Directly inverts the gripper_from_angular function. 
+ value = unnormalize(value, min_val=-0.6213, max_val=1.4910) + return normalize(value, min_val=0.4, max_val=1.5) + + +class PI0Policy(PreTrainedPolicy): + """Wrapper class around PI0FlowMatching model to train and run inference within LeRobot.""" + + config_class = PI0Config + name = "pi0" + + def __init__( + self, + config: PI0Config, + dataset_stats: dict[str, dict[str, Tensor]] | None = None, + ): + """ + Args: + config: Policy configuration class instance or None, in which case the default instantiation of + the configuration class is used. + dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected + that they will be passed with a call to `load_state_dict` before the policy is used. + """ + + super().__init__(config) + config.validate_features() + self.config = config + self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) + self.normalize_targets = Normalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + self.unnormalize_outputs = Unnormalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + + self.language_tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224") + self.model = PI0FlowMatching(config) + + self.reset() + + def reset(self): + """This should be called whenever the environment is reset.""" + self._action_queue = deque([], maxlen=self.config.n_action_steps) + + def get_optim_params(self) -> dict: + return self.parameters() + + @torch.no_grad + def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: + """Select a single action given environment observations. + + This method wraps `select_actions` in order to return one action at a time for execution in the + environment. It works by managing the actions in a queue and only calling `select_actions` when the + queue is empty. 
+ """ + self.eval() + + if self.config.adapt_to_pi_aloha: + batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT]) + + batch = self.normalize_inputs(batch) + + # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by + # querying the policy. + if len(self._action_queue) == 0: + images, img_masks = self.prepare_images(batch) + state = self.prepare_state(batch) + lang_tokens, lang_masks = self.prepare_language(batch) + + actions = self.model.sample_actions( + images, img_masks, lang_tokens, lang_masks, state, noise=noise + ) + + # Unpad actions + original_action_dim = self.config.action_feature.shape[0] + actions = actions[:, :, :original_action_dim] + + actions = self.unnormalize_outputs({"action": actions})["action"] + + if self.config.adapt_to_pi_aloha: + actions = self._pi_aloha_encode_actions(actions) + + # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue + # effectively has shape (n_action_steps, batch_size, *), hence the transpose. 
+ self._action_queue.extend(actions.transpose(0, 1)) + return self._action_queue.popleft() + + def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]: + """Do a full training forward pass to compute the loss""" + if self.config.adapt_to_pi_aloha: + batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT]) + batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION]) + + batch = self.normalize_inputs(batch) + batch = self.normalize_targets(batch) + + images, img_masks = self.prepare_images(batch) + state = self.prepare_state(batch) + lang_tokens, lang_masks = self.prepare_language(batch) + actions = self.prepare_action(batch) + actions_is_pad = batch.get("actions_id_pad") + + loss_dict = {} + losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time) + loss_dict["losses_after_forward"] = losses.clone() + + if actions_is_pad is not None: + in_episode_bound = ~actions_is_pad + losses = losses * in_episode_bound.unsqueeze(-1) + loss_dict["losses_after_in_ep_bound"] = losses.clone() + + # Remove padding + losses = losses[:, :, : self.config.max_action_dim] + loss_dict["losses_after_rm_padding"] = losses.clone() + + loss = losses.mean() + # For backward pass + loss_dict["loss"] = loss + # For logging + loss_dict["l2_loss"] = loss.item() + return loss_dict + + def prepare_images(self, batch): + """Apply Pi0 preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and + convert pixel range from [0.0, 1.0] to [-1.0, 1.0] as requested by SigLIP. + """ + images = [] + img_masks = [] + + present_img_keys = [key for key in self.config.image_features if key in batch] + missing_img_keys = [key for key in self.config.image_features if key not in batch] + + if len(present_img_keys) == 0: + raise ValueError( + f"All image features are missing from the batch. At least one expected. 
(batch: {batch.keys()}) (image_features:{self.config.image_features})" + ) + + # Preprocess image features present in the batch + for key in present_img_keys: + img = batch[key] + + if self.config.resize_imgs_with_padding is not None: + img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0) + + # Normalize from range [0,1] to [-1,1] as expected by siglip + img = img * 2.0 - 1.0 + + bsize = img.shape[0] + device = img.device + mask = torch.ones(bsize, dtype=torch.bool, device=device) + images.append(img) + img_masks.append(mask) + + # Create image features not present in the batch + # as fully 0 padded images. + for num_empty_cameras in range(len(missing_img_keys)): + if num_empty_cameras >= self.config.empty_cameras: + break + img = torch.ones_like(img) * -1 + mask = torch.zeros_like(mask) + images.append(img) + img_masks.append(mask) + + return images, img_masks + + def prepare_language(self, batch) -> tuple[Tensor, Tensor]: + """Tokenize the text input""" + device = batch[OBS_ROBOT].device + tasks = batch["task"] + + # PaliGemma prompt has to end with a new line + tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks] + + tokenized_prompt = self.language_tokenizer.__call__( + tasks, + padding="max_length", + padding_side="right", + max_length=self.config.tokenizer_max_length, + return_tensors="pt", + ) + lang_tokens = tokenized_prompt["input_ids"].to(device=device) + lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool) + + return lang_tokens, lang_masks + + def _pi_aloha_decode_state(self, state): + # Flip the joints. + for motor_idx in [1, 2, 8, 9]: + state[:, motor_idx] *= -1 + # Reverse the gripper transformation that is being applied by the Aloha runtime. + for motor_idx in [6, 13]: + state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx]) + return state + + def _pi_aloha_encode_actions(self, actions): + # Flip the joints. 
+ for motor_idx in [1, 2, 8, 9]: + actions[:, :, motor_idx] *= -1 + # Reverse the gripper transformation that is being applied by the Aloha runtime. + for motor_idx in [6, 13]: + actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx]) + return actions + + def _pi_aloha_encode_actions_inv(self, actions): + # Flip the joints again. + for motor_idx in [1, 2, 8, 9]: + actions[:, :, motor_idx] *= -1 + # Reverse the gripper transformation that is being applied by the Aloha runtime. + for motor_idx in [6, 13]: + actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx]) + return actions + + def prepare_state(self, batch): + """Pad state""" + state = pad_vector(batch[OBS_ROBOT], self.config.max_state_dim) + return state + + def prepare_action(self, batch): + """Pad action""" + actions = pad_vector(batch[ACTION], self.config.max_action_dim) + return actions + + +class PI0FlowMatching(nn.Module): + """ + π0: A Vision-Language-Action Flow Model for General Robot Control + + [Paper](https://www.physicalintelligence.company/download/pi0.pdf) + [Jax code](https://github.com/Physical-Intelligence/openpi) + + Designed by Physical Intelligence. Ported from Jax by Hugging Face. 
+ ┌──────────────────────────────┐ + │ actions │ + │ ▲ │ + │ ┌┴─────┐ │ + │ kv cache │Gemma │ │ + │ ┌──────────►│Expert│ │ + │ │ │ │ │ + │ ┌┴────────┐ │x 10 │ │ + │ │ │ └▲──▲──┘ │ + │ │PaliGemma│ │ │ │ + │ │ │ │ robot state │ + │ │ │ noise │ + │ └▲──▲─────┘ │ + │ │ │ │ + │ │ image(s) │ + │ language tokens │ + └──────────────────────────────┘ + """ + + def __init__(self, config): + super().__init__() + self.config = config + + paligemma_with_export_config = PaliGemmaWithExpertConfig( + freeze_vision_encoder=self.config.freeze_vision_encoder, + train_expert_only=self.config.train_expert_only, + attention_implementation=self.config.attention_implementation, + ) + self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_with_export_config) + + # Projections are float32 + self.state_proj = nn.Linear(self.config.max_state_dim, self.config.proj_width) + self.action_in_proj = nn.Linear(self.config.max_action_dim, self.config.proj_width) + self.action_out_proj = nn.Linear(self.config.proj_width, self.config.max_action_dim) + + self.action_time_mlp_in = nn.Linear(self.config.proj_width * 2, self.config.proj_width) + self.action_time_mlp_out = nn.Linear(self.config.proj_width, self.config.proj_width) + + self.set_requires_grad() + + def set_requires_grad(self): + for params in self.state_proj.parameters(): + params.requires_grad = self.config.train_state_proj + + def sample_noise(self, shape, device): + noise = torch.normal( + mean=0.0, + std=1.0, + size=shape, + dtype=torch.float32, + device=device, + ) + return noise + + def sample_time(self, bsize, device): + time_beta = sample_beta(1.5, 1.0, bsize, device) + time = time_beta * 0.999 + 0.001 + return time.to(dtype=torch.float32, device=device) + + def embed_prefix( + self, images, img_masks, lang_tokens, lang_masks + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Embed images with SigLIP and language tokens with embedding layer to prepare + for PaliGemma transformer processing. 
+ """ + # TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty + embs = [] + pad_masks = [] + att_masks = [] + + # TODO: remove for loop + for ( + img, + img_mask, + ) in zip(images, img_masks, strict=False): + img_emb = self.paligemma_with_expert.embed_image(img) + img_emb = img_emb.to(dtype=torch.bfloat16) + + # Normalize image embeddings + img_emb_dim = img_emb.shape[-1] + img_emb = img_emb * torch.tensor(img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device) + + bsize, num_img_embs = img_emb.shape[:2] + img_mask = img_mask[:, None].expand(bsize, num_img_embs) + + embs.append(img_emb) + pad_masks.append(img_mask) + + # Create attention masks so that image tokens attend to each other + att_masks += [0] * num_img_embs + + lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens) + + # Normalize language embeddings + lang_emb_dim = lang_emb.shape[-1] + lang_emb = lang_emb * math.sqrt(lang_emb_dim) + + embs.append(lang_emb) + pad_masks.append(lang_masks) + + # full attention between image and language inputs + num_lang_embs = lang_emb.shape[1] + att_masks += [0] * num_lang_embs + + embs = torch.cat(embs, dim=1) + pad_masks = torch.cat(pad_masks, dim=1) + att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device) + att_masks = att_masks[None, :].expand(bsize, len(att_masks)) + + return embs, pad_masks, att_masks + + def embed_suffix(self, state, noisy_actions, timestep): + """Embed state, noisy_actions, timestep to prepare for Expert Gemma processing.""" + embs = [] + pad_masks = [] + att_masks = [] + + # Embed state + state_emb = self.state_proj(state) + state_emb = state_emb.to(dtype=torch.bfloat16) + embs.append(state_emb[:, None, :]) + bsize = state_emb.shape[0] + dtype = state_emb.dtype + device = state_emb.device + + state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device) + pad_masks.append(state_mask) + + # Set attention masks so that image and language inputs do not attend to 
state or actions + att_masks += [1] + + # Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1] + time_emb = create_sinusoidal_pos_embedding( + timestep, self.config.proj_width, min_period=4e-3, max_period=4.0, device=device + ) + time_emb = time_emb.type(dtype=dtype) + + # Fuse timestep + action information using an MLP + action_emb = self.action_in_proj(noisy_actions) + + time_emb = time_emb[:, None, :].expand_as(action_emb) + action_time_emb = torch.cat([action_emb, time_emb], dim=2) + + action_time_emb = self.action_time_mlp_in(action_time_emb) + action_time_emb = F.silu(action_time_emb) # swish == silu + action_time_emb = self.action_time_mlp_out(action_time_emb) + + # Add to input tokens + embs.append(action_time_emb) + + bsize, action_time_dim = action_time_emb.shape[:2] + action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device) + pad_masks.append(action_time_mask) + + # Set attention masks so that image, language and state inputs do not attend to action tokens + att_masks += [1] + ([0] * (self.config.n_action_steps - 1)) + + embs = torch.cat(embs, dim=1) + pad_masks = torch.cat(pad_masks, dim=1) + att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device) + att_masks = att_masks[None, :].expand(bsize, len(att_masks)) + + return embs, pad_masks, att_masks + + def forward( + self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None + ) -> Tensor: + """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)""" + if noise is None: + noise = self.sample_noise(actions.shape, actions.device) + + if time is None: + time = self.sample_time(actions.shape[0], actions.device) + + time_expanded = time[:, None, None] + x_t = time_expanded * noise + (1 - time_expanded) * actions + u_t = noise - actions + + prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( + images, img_masks, lang_tokens, lang_masks + ) + 
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, time) + + pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1) + att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1) + + att_2d_masks = make_att_2d_masks(pad_masks, att_masks) + position_ids = torch.cumsum(pad_masks, dim=1) - 1 + + (_, suffix_out), _ = self.paligemma_with_expert.forward( + attention_mask=att_2d_masks, + position_ids=position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, suffix_embs], + use_cache=False, + fill_kv_cache=False, + ) + suffix_out = suffix_out[:, -self.config.n_action_steps :] + # Original openpi code, upcast attention output + suffix_out = suffix_out.to(dtype=torch.float32) + v_t = self.action_out_proj(suffix_out) + + losses = F.mse_loss(u_t, v_t, reduction="none") + return losses + + def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor: + """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)""" + bsize = state.shape[0] + device = state.device + + if noise is None: + actions_shape = (bsize, self.config.n_action_steps, self.config.max_action_dim) + noise = self.sample_noise(actions_shape, device) + + prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( + images, img_masks, lang_tokens, lang_masks + ) + prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) + prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 + + # Compute image and language key value cache + _, past_key_values = self.paligemma_with_expert.forward( + attention_mask=prefix_att_2d_masks, + position_ids=prefix_position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, None], + use_cache=self.config.use_cache, + fill_kv_cache=True, + ) + + dt = -1.0 / self.config.num_steps + dt = torch.tensor(dt, dtype=torch.float32, device=device) + + x_t = noise + time = torch.tensor(1.0, dtype=torch.float32, device=device) + while 
time >= -dt / 2: + expanded_time = time.expand(bsize) + v_t = self.denoise_step( + state, + prefix_pad_masks, + past_key_values, + x_t, + expanded_time, + ) + + # Euler step + x_t += dt * v_t + time += dt + return x_t + + def denoise_step( + self, + state, + prefix_pad_masks, + past_key_values, + x_t, + timestep, + ): + """Apply one denoising step of the noise `x_t` at a given timestep.""" + suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, timestep) + + suffix_len = suffix_pad_masks.shape[1] + batch_size = prefix_pad_masks.shape[0] + prefix_len = prefix_pad_masks.shape[1] + prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len) + + suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks) + + full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2) + + prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None] + position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1 + + outputs_embeds, _ = self.paligemma_with_expert.forward( + attention_mask=full_att_2d_masks, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=[None, suffix_embs], + use_cache=self.config.use_cache, + fill_kv_cache=False, + ) + suffix_out = outputs_embeds[1] + suffix_out = suffix_out[:, -self.config.n_action_steps :] + suffix_out = suffix_out.to(dtype=torch.float32) + v_t = self.action_out_proj(suffix_out) + return v_t diff --git a/lerobot/common/policies/pi0/paligemma_with_expert.py b/lerobot/common/policies/pi0/paligemma_with_expert.py new file mode 100644 index 000000000..08c36c11f --- /dev/null +++ b/lerobot/common/policies/pi0/paligemma_with_expert.py @@ -0,0 +1,403 @@ +from typing import List, Optional, Union + +import torch +import torch.version +from pytest import Cache +from torch import nn +from transformers import ( + AutoConfig, + GemmaForCausalLM, + PaliGemmaForConditionalGeneration, + PretrainedConfig, + PreTrainedModel, +) 
+from transformers.models.auto import CONFIG_MAPPING + +from lerobot.common.policies.pi0.flex_attention import flex_attention_forward + + +def apply_rope(x, positions, max_wavelength=10_000): + """ + Applies RoPE positions [B, L] to x [B, L, H, D]. + """ + d_half = x.shape[-1] // 2 + device = x.device + dtype = x.dtype + x = x.to(torch.float32) + + freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device) + timescale = max_wavelength**freq_exponents + radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32) + + radians = radians[..., None, :] + + sin = torch.sin(radians) # .to(dtype=dtype) + cos = torch.cos(radians) # .to(dtype=dtype) + + x1, x2 = x.split(d_half, dim=-1) + res = torch.empty_like(x) + res[..., :d_half] = x1 * cos - x2 * sin + res[..., d_half:] = x2 * cos + x1 * sin + + return res.to(dtype) + + +class PaliGemmaWithExpertConfig(PretrainedConfig): + model_type = "PaliGemmaWithExpertModel" + sub_configs = {"paligemma_config": AutoConfig, "gemma_expert_config": AutoConfig} + + def __init__( + self, + paligemma_config: dict | None = None, + gemma_expert_config: dict | None = None, + freeze_vision_encoder: bool = True, + train_expert_only: bool = True, + attention_implementation: str = "eager", + **kwargs, + ): + self.freeze_vision_encoder = freeze_vision_encoder + self.train_expert_only = train_expert_only + self.attention_implementation = attention_implementation + + if paligemma_config is None: + # Default config from Pi0 + self.paligemma_config = CONFIG_MAPPING["paligemma"]( + transformers_version="4.48.1", + _vocab_size=257152, + bos_token_id=2, + eos_token_id=1, + hidden_size=2048, + image_token_index=257152, + model_type="paligemma", + pad_token_id=0, + projection_dim=2048, + text_config={ + "hidden_activation": "gelu_pytorch_tanh", + "hidden_size": 2048, + "intermediate_size": 16384, + "model_type": "gemma", + "num_attention_heads": 8, + "num_hidden_layers": 18, + 
"num_image_tokens": 256, + "num_key_value_heads": 1, + "torch_dtype": "float32", + "vocab_size": 257152, + }, + vision_config={ + "hidden_size": 1152, + "intermediate_size": 4304, + "model_type": "siglip_vision_model", + "num_attention_heads": 16, + "num_hidden_layers": 27, + "num_image_tokens": 256, + "patch_size": 14, + "projection_dim": 2048, + "projector_hidden_act": "gelu_fast", + "torch_dtype": "float32", + "vision_use_head": False, + }, + ) + elif isinstance(self.paligemma_config, dict): + # Override Pi0 default config for PaliGemma + if "model_type" not in gemma_expert_config: + paligemma_config["model_type"] = "paligemma" + + cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]] + self.paligemma_config = cfg_cls(**paligemma_config) + + if gemma_expert_config is None: + # Default config from Pi0 + self.gemma_expert_config = CONFIG_MAPPING["gemma"]( + attention_bias=False, + attention_dropout=0.0, + bos_token_id=2, + eos_token_id=1, + head_dim=256, + hidden_act="gelu_pytorch_tanh", + hidden_activation="gelu_pytorch_tanh", + hidden_size=1024, + initializer_range=0.02, + intermediate_size=4096, + max_position_embeddings=8192, + model_type="gemma", + num_attention_heads=8, + num_hidden_layers=18, + num_key_value_heads=1, + pad_token_id=0, + rms_norm_eps=1e-06, + rope_theta=10000.0, + torch_dtype="float32", + transformers_version="4.48.1", + use_cache=True, + vocab_size=257152, + ) + elif isinstance(self.gemma_expert_config, dict): + # Override Pi0 default config for Gemma Expert + if "model_type" not in gemma_expert_config: + gemma_expert_config["model_type"] = "gemma" + + cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]] + self.gemma_expert_config = cfg_cls(**gemma_expert_config) + + super().__init__(**kwargs) + + def __post_init__(self): + super().__post_init__() + if self.train_expert_only and not self.freeze_vision_encoder: + raise ValueError( + "You set `freeze_vision_encoder=False` and `train_expert_only=True` which are not compatible." 
+ ) + + if self.attention_implementation not in ["eager", "fa2", "flex"]: + raise ValueError( + f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'." + ) + + +class PaliGemmaWithExpertModel(PreTrainedModel): + config_class = PaliGemmaWithExpertConfig + + def __init__(self, config: PaliGemmaWithExpertConfig): + super().__init__(config=config) + self.config = config + self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config) + self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config) + # Remove unused embed_tokens + self.gemma_expert.model.embed_tokens = None + + self.to_bfloat16_like_physical_intelligence() + self.set_requires_grad() + + def set_requires_grad(self): + if self.config.freeze_vision_encoder: + self.paligemma.vision_tower.eval() + for params in self.paligemma.vision_tower.parameters(): + params.requires_grad = False + + if self.config.train_expert_only: + self.paligemma.eval() + for params in self.paligemma.parameters(): + params.requires_grad = False + + def train(self, mode: bool = True): + super().train(mode) + + if self.config.freeze_vision_encoder: + self.paligemma.vision_tower.eval() + + if self.config.train_expert_only: + self.paligemma.eval() + + def to_bfloat16_like_physical_intelligence(self): + self.paligemma = self.paligemma.to(dtype=torch.bfloat16) + + params_to_change_dtype = [ + "language_model.model.layers", + "gemma_expert.model.layers", + "vision_tower", + "multi_modal", + ] + for name, param in self.named_parameters(): + if any(selector in name for selector in params_to_change_dtype): + param.data = param.data.to(dtype=torch.bfloat16) + + def embed_image(self, image: torch.Tensor): + return self.paligemma.get_image_features(image) + + def embed_language_tokens(self, tokens: torch.Tensor): + return self.paligemma.language_model.model.embed_tokens(tokens) + + # TODO: break down this huge forward into modules or functions + def 
forward( + self, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, + inputs_embeds: List[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + fill_kv_cache: Optional[bool] = None, + ): + models = [self.paligemma.language_model.model, self.gemma_expert.model] + + for hidden_states in inputs_embeds: + # TODO this is very inefficient + # dtype is always the same, batch size too (if > 1 len) + # device could be trickier in multi gpu edge cases but that's it + if hidden_states is None: + continue + batch_size = hidden_states.shape[0] + + # RMSNorm + num_layers = self.paligemma.config.text_config.num_hidden_layers + head_dim = self.paligemma.config.text_config.head_dim + for layer_idx in range(num_layers): + query_states = [] + key_states = [] + value_states = [] + for i, hidden_states in enumerate(inputs_embeds): + if hidden_states is None: + continue + layer = models[i].layers[layer_idx] + # normalizer = torch.tensor(models[i].config.hidden_size**0.5, dtype=hidden_states.dtype) + # hidden_states = hidden_states * normalizer + hidden_states = layer.input_layernorm(hidden_states) + + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) + + hidden_states = hidden_states.to(dtype=torch.bfloat16) + query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape) + key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape) + value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape) + + query_states.append(query_state) + key_states.append(key_state) + value_states.append(value_state) + + # B,L,H,D with L sequence length, H number of heads, D head dim + # concatenate on the number of embeddings/tokens + query_states = torch.cat(query_states, dim=1) + key_states = torch.cat(key_states, dim=1) + value_states = torch.cat(value_states, dim=1) + + query_states = 
apply_rope(query_states, position_ids) + key_states = apply_rope(key_states, position_ids) + + if use_cache and past_key_values is None: + past_key_values = {} + + if use_cache: + if fill_kv_cache: + past_key_values[layer_idx] = { + "key_states": key_states, + "value_states": value_states, + } + else: + # TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before. + # so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach + # the max len, then we (for instance) double the cache size. This implementation already exists + # in `transformers`. (molbap) + key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1) + value_states = torch.cat( + [past_key_values[layer_idx]["value_states"], value_states], dim=1 + ) + + attention_interface = self.get_attention_interface() + att_output = attention_interface( + attention_mask, batch_size, head_dim, query_states, key_states, value_states + ) + att_output = att_output.to(dtype=torch.bfloat16) + + # first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len]) + outputs_embeds = [] + start = 0 + for i, hidden_states in enumerate(inputs_embeds): + layer = models[i].layers[layer_idx] + + if hidden_states is not None: + end = start + hidden_states.shape[1] + + if att_output.dtype != layer.self_attn.o_proj.weight.dtype: + att_output = att_output.to(layer.self_attn.o_proj.weight.dtype) + out_emb = layer.self_attn.o_proj(att_output[:, start:end]) + + # TODO: first dropout (by default 0.0) + + # first residual + out_emb += hidden_states + after_first_residual = out_emb.clone() + + out_emb = layer.post_attention_layernorm(out_emb) + out_emb = layer.mlp(out_emb) + + # TODO: second dropout (by default 0.0) + + # second residual + out_emb += after_first_residual + + outputs_embeds.append(out_emb) + + start = end + else: + outputs_embeds.append(None) + + inputs_embeds = outputs_embeds + + # final norm + 
outputs_embeds = [] + for i, hidden_states in enumerate(inputs_embeds): + if hidden_states is not None: + out_emb = models[i].norm(hidden_states) + outputs_embeds.append(out_emb) + else: + outputs_embeds.append(None) + + return outputs_embeds, past_key_values + + def get_attention_interface(self): + if self.config.attention_implementation == "fa2": + attention_interface = self.flash_attention_forward + elif self.config.attention_implementation == "flex": + attention_interface = flex_attention_forward + else: + attention_interface = self.eager_attention_forward + return attention_interface + + def flash_attention_forward( + self, attention_mask, batch_size, head_dim, query_states, key_states, value_states + ): + raise NotImplementedError("FA2 is not implemented (yet)") + + def eager_attention_forward( + self, attention_mask, batch_size, head_dim, query_states, key_states, value_states + ): + num_att_heads = self.config.paligemma_config.text_config.num_attention_heads + num_key_value_heads = self.config.paligemma_config.text_config.num_key_value_heads + num_key_value_groups = num_att_heads // num_key_value_heads + + # query_states: batch_size, sequence_length, num_att_head, head_dim + # key_states: batch_size, sequence_length, num_key_value_head, head_dim + # value_states: batch_size, sequence_length, num_key_value_head, head_dim + sequence_length = key_states.shape[1] + + key_states = key_states[:, :, :, None, :].expand( + batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim + ) + key_states = key_states.reshape( + batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim + ) + + value_states = value_states[:, :, :, None, :].expand( + batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim + ) + value_states = value_states.reshape( + batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim + ) + + # Attention here is upcasted to float32 to match the original eager 
implementation. + + query_states = query_states.to(dtype=torch.float32) + key_states = key_states.to(dtype=torch.float32) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + + att_weights = torch.matmul(query_states, key_states.transpose(2, 3)) + att_weights *= head_dim**-0.5 + big_neg = -2.3819763e38 # See gemma/modules.py + + masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg) + + probs = nn.functional.softmax(masked_att_weights, dim=-1) + probs = probs.to(dtype=value_states.dtype) + + # probs: batch_size, num_key_value_head, num_att_head, sequence_length, sequence_length + # value_states: batch_size, sequence_length, num_att_heads, head_dim + + att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3)) + + att_output = att_output.permute(0, 2, 1, 3) + # we use -1 because sequence length can change + att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim) + + return att_output diff --git a/lerobot/common/policies/tdmpc/modeling_tdmpc.py b/lerobot/common/policies/tdmpc/modeling_tdmpc.py index 519bc2200..6366a5a48 100644 --- a/lerobot/common/policies/tdmpc/modeling_tdmpc.py +++ b/lerobot/common/policies/tdmpc/modeling_tdmpc.py @@ -319,7 +319,7 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: # (b, t) -> (t, b) for key in batch: - if batch[key].ndim > 1: + if isinstance(batch[key], torch.Tensor) and batch[key].ndim > 1: batch[key] = batch[key].transpose(1, 0) action = batch["action"] # (t, b, action_dim) @@ -502,7 +502,7 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: # Undo (b, t) -> (t, b). 
for key in batch: - if batch[key].ndim > 1: + if isinstance(batch[key], torch.Tensor) and batch[key].ndim > 1: batch[key] = batch[key].transpose(1, 0) return info diff --git a/lerobot/common/utils/utils.py b/lerobot/common/utils/utils.py index fda3edd9f..cb4f1874c 100644 --- a/lerobot/common/utils/utils.py +++ b/lerobot/common/utils/utils.py @@ -74,6 +74,18 @@ def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device: return device +def get_safe_dtype(dtype: torch.dtype, device: str | torch.device): + """ + mps is currently not compatible with float64 + """ + if isinstance(device, torch.device): + device = device.type + if device == "mps" and dtype == torch.float64: + return torch.float32 + else: + return dtype + + def is_torch_device_available(try_device: str) -> bool: if try_device == "cuda": return torch.cuda.is_available() diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index acc3d9f35..253bc45ca 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -125,6 +125,9 @@ def rollout( # Reset the policy and environments. policy.reset() + if hasattr(policy, "use_ema_modules"): + policy.use_ema_modules() + observation, info = env.reset(seed=seeds) if render_callback is not None: render_callback(env) @@ -205,6 +208,9 @@ def rollout( stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1) ret["observation"] = stacked_observations + if hasattr(policy, "use_original_modules"): + policy.use_original_modules() + return ret @@ -235,7 +241,9 @@ def eval_policy( raise ValueError("If max_episodes_rendered > 0, videos_dir must be provided.") if not isinstance(policy, PreTrainedPolicy): - raise ValueError(policy) + raise ValueError( + f"Policy of type 'PreTrainedPolicy' is expected, but type '{type(policy)}' was provided." 
+ ) start = time.time() policy.eval() diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 954c79699..9af1a9720 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -38,6 +38,7 @@ from lerobot.common.policies.utils import get_device_from_parameters from lerobot.common.utils.utils import ( format_big_number, + get_safe_dtype, get_safe_torch_device, has_method, init_logging, @@ -86,6 +87,10 @@ def update_policy( optimizer.zero_grad() + if hasattr(policy, "update_ema_modules"): + policy.update_ema_modules() + + # Step through pytorch scheduler at every batch instead of epoch if lr_scheduler is not None: lr_scheduler.step() @@ -215,6 +220,7 @@ def train(cfg: TrainPipelineConfig): device=device, ds_meta=offline_dataset.meta, ) + logging.info("Creating optimizer and scheduler") optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy) grad_scaler = GradScaler(device, enabled=cfg.use_amp) @@ -296,6 +302,10 @@ def evaluate_and_checkpoint_if_needed(step: int, is_online: bool): dl_iter = cycle(dataloader) policy.train() + + if hasattr(policy, "init_ema_modules"): + policy.init_ema_modules() + offline_step = 0 for _ in range(step, cfg.offline.steps): if offline_step == 0: @@ -306,7 +316,8 @@ def evaluate_and_checkpoint_if_needed(step: int, is_online: bool): dataloading_s = time.perf_counter() - start_time for key in batch: - batch[key] = batch[key].to(device, non_blocking=True) + if isinstance(batch[key], torch.Tensor): + batch[key] = batch[key].to(device, non_blocking=True) train_info = update_policy( policy, @@ -365,6 +376,8 @@ def evaluate_and_checkpoint_if_needed(step: int, is_online: bool): "next.reward": {"shape": (), "dtype": np.dtype("float32")}, "next.done": {"shape": (), "dtype": np.dtype("?")}, "task_index": {"shape": (), "dtype": np.dtype("int64")}, + # FIXME: 'task' is a string + # "task": {"shape": (), "dtype": np.dtype("?")}, # FIXME: 'next.success' is expected by pusht env but not xarm "next.success": {"shape": 
(), "dtype": np.dtype("?")}, }, @@ -451,9 +464,10 @@ def sample_trajectory_and_update_buffer(): if len(offline_dataset.meta.tasks) > 1: raise NotImplementedError("Add support for multi task.") - # Hack to add a task to the online_dataset (0 is the first task of the offline_dataset) + # TODO(rcadene, aliberts): Hack to add a task to the online_dataset (0 is the first task of the offline_dataset) total_num_frames = eval_info["episodes"]["index"].shape[0] eval_info["episodes"]["task_index"] = torch.tensor([0] * total_num_frames, dtype=torch.int64) + eval_info["episodes"]["task"] = ["do the thing"] * total_num_frames with lock if lock is not None else nullcontext(): start_update_buffer_time = time.perf_counter() @@ -499,7 +513,9 @@ def sample_trajectory_and_update_buffer(): dataloading_s = time.perf_counter() - start_time for key in batch: - batch[key] = batch[key].to(device, non_blocking=True) + if isinstance(batch[key], torch.Tensor): + dtype = get_safe_dtype(batch[key].dtype, device) + batch[key] = batch[key].to(device=device, dtype=dtype, non_blocking=True) train_info = update_policy( policy, diff --git a/lerobot/templates/visualize_dataset_homepage.html b/lerobot/templates/visualize_dataset_homepage.html index adff07be7..19613afb5 100644 --- a/lerobot/templates/visualize_dataset_homepage.html +++ b/lerobot/templates/visualize_dataset_homepage.html @@ -7,7 +7,7 @@ -Example Datasets:
- -
- \ No newline at end of file + diff --git a/lerobot/templates/visualize_dataset_template.html b/lerobot/templates/visualize_dataset_template.html index 3c93d2d62..08de3e3dd 100644 --- a/lerobot/templates/visualize_dataset_template.html +++ b/lerobot/templates/visualize_dataset_template.html @@ -107,8 +107,8 @@

filter videos
🔽
- -