From 64b5920e948a9f8ddfac9f7a039096a1d7299753 Mon Sep 17 00:00:00 2001 From: Cadene Date: Sat, 24 Feb 2024 18:19:18 +0000 Subject: [PATCH] format --- lerobot/common/tdmpc.py | 2 +- lerobot/scripts/eval.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/lerobot/common/tdmpc.py b/lerobot/common/tdmpc.py index 902673a..f813327 100644 --- a/lerobot/common/tdmpc.py +++ b/lerobot/common/tdmpc.py @@ -51,7 +51,7 @@ class TOLD(nn.Module): """Predicts next latent state (d) and single-step reward (R).""" x = torch.cat([z, a], dim=-1) return self._dynamics(x), self._reward(x) - + def next_dynamics(self, z, a): """Predicts next latent state (d).""" x = torch.cat([z, a], dim=-1) diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 62d482e..12c5d14 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -1,3 +1,4 @@ +import threading from pathlib import Path import hydra @@ -11,11 +12,12 @@ from torchrl.envs import EnvBase from lerobot.common.envs.factory import make_env from lerobot.common.tdmpc import TDMPC from lerobot.common.utils import set_seed -import threading + def write_video(video_path, stacked_frames, fps): imageio.mimsave(video_path, stacked_frames, fps=fps) + def eval_policy( env: EnvBase, policy: TensorDictModule = None,