fix train.py, stats, eval.py (training is running)
This commit is contained in:
@@ -243,10 +243,9 @@ class DiffusionUnetImagePolicy(BaseImagePolicy):
|
||||
result = {"action": action, "action_pred": action_pred}
|
||||
return result
|
||||
|
||||
def compute_loss(self, batch):
|
||||
assert "valid_mask" not in batch
|
||||
nobs = batch["obs"]
|
||||
nactions = batch["action"]
|
||||
def compute_loss(self, obs_dict, action):
|
||||
nobs = obs_dict
|
||||
nactions = action
|
||||
batch_size = nactions.shape[0]
|
||||
horizon = nactions.shape[1]
|
||||
|
||||
|
||||
@@ -157,7 +157,8 @@ class DiffusionPolicy(nn.Module):
|
||||
"image": batch["observation.image"],
|
||||
"agent_pos": batch["observation.state"],
|
||||
}
|
||||
loss = self.diffusion.compute_loss(obs_dict)
|
||||
action = batch["action"]
|
||||
loss = self.diffusion.compute_loss(obs_dict, action)
|
||||
loss.backward()
|
||||
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
|
||||
Reference in New Issue
Block a user