fix: action_is_pad was missing in compute_loss

This commit is contained in:
Cadene
2024-04-05 11:33:39 +00:00
parent ad3379a73a
commit a420714ee4
2 changed files with 13 additions and 10 deletions

View File

@@ -153,12 +153,7 @@ class DiffusionPolicy(nn.Module):
data_s = time.time() - start_time
obs_dict = {
"image": batch["observation.image"],
"agent_pos": batch["observation.state"],
}
action = batch["action"]
loss = self.diffusion.compute_loss(obs_dict, action)
loss = self.diffusion.compute_loss(batch)
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(