fix train.py, stats, eval.py (training is running)

This commit is contained in:
Cadene
2024-04-05 09:31:39 +00:00
parent c93ce35d8c
commit 5af00d0c1e
11 changed files with 76 additions and 72 deletions

View File

@@ -157,7 +157,8 @@ class DiffusionPolicy(nn.Module):
"image": batch["observation.image"],
"agent_pos": batch["observation.state"],
}
loss = self.diffusion.compute_loss(obs_dict)
action = batch["action"]
loss = self.diffusion.compute_loss(obs_dict, action)
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(