backup wip

This commit is contained in:
Alexander Soare
2024-04-11 18:33:54 +01:00
parent 94cc22da9e
commit 5666ec3ec7
3 changed files with 53 additions and 89 deletions

View File

@@ -130,9 +130,9 @@ class DiffusionPolicy(nn.Module):
def _generate_actions(self, batch):
if not self.training and self.ema_diffusion is not None:
return self.ema_diffusion.predict_action(batch)
return self.ema_diffusion.generate_actions(batch)
else:
return self.diffusion.predict_action(batch)
return self.diffusion.generate_actions(batch)
def update(self, batch, **_):
"""Run the model in train mode, compute the loss, and do an optimization step."""