Remove loss masking from diffusion policy (#135)

This commit is contained in:
Alexander Soare
2024-05-06 07:27:01 +01:00
committed by GitHub
parent f5e76393eb
commit a8e245fb31
5 changed files with 10 additions and 1 deletions

View File

@@ -268,7 +268,7 @@ class DiffusionModel(nn.Module):
loss = F.mse_loss(pred, target, reduction="none")
# Mask loss wherever the action is padded with copies (edges of the dataset trajectory).
if "action_is_pad" in batch:
if self.config.do_mask_loss_for_padding and "action_is_pad" in batch:
in_episode_bound = ~batch["action_is_pad"]
loss = loss * in_episode_bound.unsqueeze(-1)