This commit is contained in:
Cadene
2024-02-24 18:19:18 +00:00
parent aed02dc7c6
commit 64b5920e94
2 changed files with 4 additions and 2 deletions

View File

@@ -51,7 +51,7 @@ class TOLD(nn.Module):
"""Predicts next latent state (d) and single-step reward (R)."""
x = torch.cat([z, a], dim=-1)
return self._dynamics(x), self._reward(x)
def next_dynamics(self, z, a):
"""Predicts next latent state (d)."""
x = torch.cat([z, a], dim=-1)

View File

@@ -1,3 +1,4 @@
import threading
from pathlib import Path
import hydra
@@ -11,11 +12,12 @@ from torchrl.envs import EnvBase
from lerobot.common.envs.factory import make_env
from lerobot.common.tdmpc import TDMPC
from lerobot.common.utils import set_seed
import threading
def write_video(video_path, stacked_frames, fps):
imageio.mimsave(video_path, stacked_frames, fps=fps)
def eval_policy(
env: EnvBase,
policy: TensorDictModule = None,