This commit is contained in:
Thomas Wolf
2024-04-02 22:49:16 +02:00
parent 4751642ace
commit 24821fee24

View File

@@ -48,10 +48,10 @@ class AbstractPolicy(nn.Module, PyTorchModelHubMixin):
"""One step of the policy's learning algorithm."""
raise NotImplementedError("Abstract method")
def save(self, fp):
def save(self, fp): # TODO: remove this method since we are using PyTorchModelHubMixin
torch.save(self.state_dict(), fp)
def load(self, fp):
def load(self, fp): # TODO: remove this method since we are using PyTorchModelHubMixin
d = torch.load(fp)
self.load_state_dict(d)