format
This commit is contained in:
@@ -51,7 +51,7 @@ class TOLD(nn.Module):
|
||||
"""Predicts next latent state (d) and single-step reward (R)."""
|
||||
x = torch.cat([z, a], dim=-1)
|
||||
return self._dynamics(x), self._reward(x)
|
||||
|
||||
|
||||
def next_dynamics(self, z, a):
|
||||
"""Predicts next latent state (d)."""
|
||||
x = torch.cat([z, a], dim=-1)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
import hydra
|
||||
@@ -11,11 +12,12 @@ from torchrl.envs import EnvBase
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.tdmpc import TDMPC
|
||||
from lerobot.common.utils import set_seed
|
||||
import threading
|
||||
|
||||
|
||||
def write_video(video_path, stacked_frames, fps):
|
||||
imageio.mimsave(video_path, stacked_frames, fps=fps)
|
||||
|
||||
|
||||
def eval_policy(
|
||||
env: EnvBase,
|
||||
policy: TensorDictModule = None,
|
||||
|
||||
Reference in New Issue
Block a user