This commit is contained in:
Cadene
2024-02-24 18:19:18 +00:00
parent aed02dc7c6
commit 64b5920e94
2 changed files with 4 additions and 2 deletions

View File

@@ -51,7 +51,7 @@ class TOLD(nn.Module):
"""Predicts next latent state (d) and single-step reward (R)."""
x = torch.cat([z, a], dim=-1)
return self._dynamics(x), self._reward(x)
def next_dynamics(self, z, a):
"""Predicts next latent state (d)."""
x = torch.cat([z, a], dim=-1)