Skip to content

Commit

Permalink
Fixes following #670 (#719)
Browse files Browse the repository at this point in the history
  • Loading branch information
aliberts authored Feb 12, 2025
1 parent 90e099b commit e710959
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 7 deletions.
8 changes: 4 additions & 4 deletions lerobot/common/policies/pi0/modeling_pi0.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()

def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]:
"""Do a full training forward pass to compute the loss"""
if self.config.adapt_to_pi_aloha:
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
Expand Down Expand Up @@ -328,12 +328,12 @@ def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str,
losses = losses[:, :, : self.config.max_action_dim]
loss_dict["losses_after_rm_padding"] = losses.clone()

loss = losses.mean()
# For backward pass
loss_dict["loss"] = loss
loss = losses.mean()
# For logging
loss_dict["l2_loss"] = loss.item()
return loss_dict

return loss, loss_dict

def prepare_images(self, batch):
"""Apply Pi0 preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and
Expand Down
4 changes: 2 additions & 2 deletions lerobot/common/utils/wandb_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def log_policy(self, checkpoint_dir: Path):
self._wandb.log_artifact(artifact)

def log_dict(self, d: dict, step: int, mode: str = "train"):
if mode in {"train", "eval"}:
if mode not in {"train", "eval"}:
raise ValueError(mode)

for k, v in d.items():
Expand All @@ -114,7 +114,7 @@ def log_dict(self, d: dict, step: int, mode: str = "train"):
self._wandb.log({f"{mode}/{k}": v}, step=step)

def log_video(self, video_path: str, step: int, mode: str = "train"):
if mode in {"train", "eval"}:
if mode not in {"train", "eval"}:
raise ValueError(mode)

wandb_video = self._wandb.Video(video_path, fps=self.env_fps, format="mp4")
Expand Down
3 changes: 2 additions & 1 deletion lerobot/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def train(cfg: TrainPipelineConfig):
logging.info(train_tracker)
if wandb_logger:
wandb_log_dict = {**train_tracker.to_dict(), **output_dict}
wandb_logger.log_dict(wandb_log_dict)
wandb_logger.log_dict(wandb_log_dict, step)
train_tracker.reset_averages()

if cfg.save_checkpoint and is_saving_step:
Expand Down Expand Up @@ -271,6 +271,7 @@ def train(cfg: TrainPipelineConfig):
logging.info(eval_tracker)
if wandb_logger:
wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
wandb_logger.log_video(eval_info["video_paths"][0], step, mode="eval")

if eval_env:
Expand Down

0 comments on commit e710959

Please sign in to comment.