diff --git a/lerobot/common/policies/abstract.py b/lerobot/common/policies/abstract.py index 34b4ea33e..923af4537 100644 --- a/lerobot/common/policies/abstract.py +++ b/lerobot/common/policies/abstract.py @@ -48,10 +48,10 @@ class AbstractPolicy(nn.Module, PyTorchModelHubMixin): """One step of the policy's learning algorithm.""" raise NotImplementedError("Abstract method") - def save(self, fp): + def save(self, fp): # TODO: remove this method since we are using PyTorchModelHubMixin torch.save(self.state_dict(), fp) - def load(self, fp): + def load(self, fp): # TODO: remove this method since we are using PyTorchModelHubMixin d = torch.load(fp) self.load_state_dict(d)