backup wip
This commit is contained in:
@@ -130,9 +130,9 @@ class DiffusionPolicy(nn.Module):
|
||||
|
||||
def _generate_actions(self, batch):
|
||||
if not self.training and self.ema_diffusion is not None:
|
||||
return self.ema_diffusion.predict_action(batch)
|
||||
return self.ema_diffusion.generate_actions(batch)
|
||||
else:
|
||||
return self.diffusion.predict_action(batch)
|
||||
return self.diffusion.generate_actions(batch)
|
||||
|
||||
def update(self, batch, **_):
|
||||
"""Run the model in train mode, compute the loss, and do an optimization step."""
|
||||
|
||||
Reference in New Issue
Block a user