diff --git a/lerobot/common/tdmpc.py b/lerobot/common/tdmpc.py index 902673aa..f8133279 100644 --- a/lerobot/common/tdmpc.py +++ b/lerobot/common/tdmpc.py @@ -51,7 +51,7 @@ class TOLD(nn.Module): """Predicts next latent state (d) and single-step reward (R).""" x = torch.cat([z, a], dim=-1) return self._dynamics(x), self._reward(x) - + def next_dynamics(self, z, a): """Predicts next latent state (d).""" x = torch.cat([z, a], dim=-1) diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 62d482e0..12c5d14b 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -1,3 +1,4 @@ +import threading from pathlib import Path import hydra @@ -11,11 +12,12 @@ from torchrl.envs import EnvBase from lerobot.common.envs.factory import make_env from lerobot.common.tdmpc import TDMPC from lerobot.common.utils import set_seed -import threading + def write_video(video_path, stacked_frames, fps): imageio.mimsave(video_path, stacked_frames, fps=fps) + def eval_policy( env: EnvBase, policy: TensorDictModule = None,