fix train.py, stats, eval.py (training is running)

This commit is contained in:
Cadene
2024-04-05 09:31:39 +00:00
parent c93ce35d8c
commit 5af00d0c1e
11 changed files with 76 additions and 72 deletions

View File

@@ -243,10 +243,9 @@ class DiffusionUnetImagePolicy(BaseImagePolicy):
result = {"action": action, "action_pred": action_pred}
return result
def compute_loss(self, batch):
assert "valid_mask" not in batch
nobs = batch["obs"]
nactions = batch["action"]
def compute_loss(self, obs_dict, action):
nobs = obs_dict
nactions = action
batch_size = nactions.shape[0]
horizon = nactions.shape[1]

View File

@@ -157,7 +157,8 @@ class DiffusionPolicy(nn.Module):
"image": batch["observation.image"],
"agent_pos": batch["observation.state"],
}
loss = self.diffusion.compute_loss(obs_dict)
action = batch["action"]
loss = self.diffusion.compute_loss(obs_dict, action)
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(