fix: action_is_pad was missing in compute_loss
This commit is contained in:
@@ -153,12 +153,7 @@ class DiffusionPolicy(nn.Module):
|
||||
|
||||
data_s = time.time() - start_time
|
||||
|
||||
obs_dict = {
|
||||
"image": batch["observation.image"],
|
||||
"agent_pos": batch["observation.state"],
|
||||
}
|
||||
action = batch["action"]
|
||||
loss = self.diffusion.compute_loss(obs_dict, action)
|
||||
loss = self.diffusion.compute_loss(batch)
|
||||
loss.backward()
|
||||
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
|
||||
Reference in New Issue
Block a user