format
This commit is contained in:
@@ -51,7 +51,7 @@ class TOLD(nn.Module):
|
|||||||
"""Predicts next latent state (d) and single-step reward (R)."""
|
"""Predicts next latent state (d) and single-step reward (R)."""
|
||||||
x = torch.cat([z, a], dim=-1)
|
x = torch.cat([z, a], dim=-1)
|
||||||
return self._dynamics(x), self._reward(x)
|
return self._dynamics(x), self._reward(x)
|
||||||
|
|
||||||
def next_dynamics(self, z, a):
|
def next_dynamics(self, z, a):
|
||||||
"""Predicts next latent state (d)."""
|
"""Predicts next latent state (d)."""
|
||||||
x = torch.cat([z, a], dim=-1)
|
x = torch.cat([z, a], dim=-1)
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import threading
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import hydra
|
import hydra
|
||||||
@@ -11,11 +12,12 @@ from torchrl.envs import EnvBase
|
|||||||
from lerobot.common.envs.factory import make_env
|
from lerobot.common.envs.factory import make_env
|
||||||
from lerobot.common.tdmpc import TDMPC
|
from lerobot.common.tdmpc import TDMPC
|
||||||
from lerobot.common.utils import set_seed
|
from lerobot.common.utils import set_seed
|
||||||
import threading
|
|
||||||
|
|
||||||
def write_video(video_path, stacked_frames, fps):
|
def write_video(video_path, stacked_frames, fps):
|
||||||
imageio.mimsave(video_path, stacked_frames, fps=fps)
|
imageio.mimsave(video_path, stacked_frames, fps=fps)
|
||||||
|
|
||||||
|
|
||||||
def eval_policy(
|
def eval_policy(
|
||||||
env: EnvBase,
|
env: EnvBase,
|
||||||
policy: TensorDictModule = None,
|
policy: TensorDictModule = None,
|
||||||
|
|||||||
Reference in New Issue
Block a user