Fixes following #670 (#719)

This commit is contained in:
Simon Alibert
2025-02-12 12:53:55 +01:00
committed by GitHub
parent 90e099b39f
commit e71095960f
3 changed files with 8 additions and 7 deletions

View File

@@ -300,7 +300,7 @@ class PI0Policy(PreTrainedPolicy):
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]:
"""Do a full training forward pass to compute the loss"""
if self.config.adapt_to_pi_aloha:
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
@@ -328,12 +328,12 @@ class PI0Policy(PreTrainedPolicy):
losses = losses[:, :, : self.config.max_action_dim]
loss_dict["losses_after_rm_padding"] = losses.clone()
loss = losses.mean()
# For backward pass
loss_dict["loss"] = loss
loss = losses.mean()
# For logging
loss_dict["l2_loss"] = loss.item()
return loss_dict
return loss, loss_dict
def prepare_images(self, batch):
"""Apply Pi0 preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and

View File

@@ -102,7 +102,7 @@ class WandBLogger:
self._wandb.log_artifact(artifact)
def log_dict(self, d: dict, step: int, mode: str = "train"):
if mode in {"train", "eval"}:
if mode not in {"train", "eval"}:
raise ValueError(mode)
for k, v in d.items():
@@ -114,7 +114,7 @@ class WandBLogger:
self._wandb.log({f"{mode}/{k}": v}, step=step)
def log_video(self, video_path: str, step: int, mode: str = "train"):
if mode in {"train", "eval"}:
if mode not in {"train", "eval"}:
raise ValueError(mode)
wandb_video = self._wandb.Video(video_path, fps=self.env_fps, format="mp4")