diff --git a/lerobot/common/policies/pi0/modeling_pi0.py b/lerobot/common/policies/pi0/modeling_pi0.py
index 90d1a14c9..c8b12cafc 100644
--- a/lerobot/common/policies/pi0/modeling_pi0.py
+++ b/lerobot/common/policies/pi0/modeling_pi0.py
@@ -300,7 +300,7 @@ def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -
         self._action_queue.extend(actions.transpose(0, 1))
         return self._action_queue.popleft()
 
-    def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
+    def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]:
         """Do a full training forward pass to compute the loss"""
         if self.config.adapt_to_pi_aloha:
             batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
@@ -328,12 +328,12 @@ def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str,
         losses = losses[:, :, : self.config.max_action_dim]
         loss_dict["losses_after_rm_padding"] = losses.clone()
 
-        loss = losses.mean()  # For backward pass
-        loss_dict["loss"] = loss
+        loss = losses.mean()
         # For logging
         loss_dict["l2_loss"] = loss.item()
-        return loss_dict
+
+        return loss, loss_dict
 
     def prepare_images(self, batch):
         """Apply Pi0 preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and
diff --git a/lerobot/common/utils/wandb_utils.py b/lerobot/common/utils/wandb_utils.py
index 2ab3c3fd0..9985b894c 100644
--- a/lerobot/common/utils/wandb_utils.py
+++ b/lerobot/common/utils/wandb_utils.py
@@ -102,7 +102,7 @@ def log_policy(self, checkpoint_dir: Path):
         self._wandb.log_artifact(artifact)
 
     def log_dict(self, d: dict, step: int, mode: str = "train"):
-        if mode in {"train", "eval"}:
+        if mode not in {"train", "eval"}:
             raise ValueError(mode)
 
         for k, v in d.items():
@@ -114,7 +114,7 @@ def log_dict(self, d: dict, step: int, mode: str = "train"):
             self._wandb.log({f"{mode}/{k}": v}, step=step)
 
     def log_video(self, video_path: str, step: int, mode: str = "train"):
-        if mode in {"train", "eval"}:
+        if mode not in {"train", "eval"}:
             raise ValueError(mode)
 
         wandb_video = self._wandb.Video(video_path, fps=self.env_fps, format="mp4")
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index a840b33d0..7a31f2f5f 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -233,7 +233,7 @@ def train(cfg: TrainPipelineConfig):
             logging.info(train_tracker)
             if wandb_logger:
                 wandb_log_dict = {**train_tracker.to_dict(), **output_dict}
-                wandb_logger.log_dict(wandb_log_dict)
+                wandb_logger.log_dict(wandb_log_dict, step)
             train_tracker.reset_averages()
 
         if cfg.save_checkpoint and is_saving_step:
@@ -271,6 +271,7 @@ def train(cfg: TrainPipelineConfig):
             logging.info(eval_tracker)
            if wandb_logger:
                 wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
+                wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
                 wandb_logger.log_video(eval_info["video_paths"][0], step, mode="eval")
 
     if eval_env: