wip: still needs batch logic for act and tdmpc

This commit is contained in:
Alexander Soare
2024-03-14 15:22:55 +00:00
parent 8c56770318
commit ba91976944
11 changed files with 240 additions and 100 deletions

View File

@@ -7,6 +7,7 @@ import torch
from tensordict.nn import TensorDictModule
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers import PrioritizedSliceSampler
from torchrl.envs import SerialEnv
from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.envs.factory import make_env
@@ -148,6 +149,14 @@ def train(cfg: dict, out_dir=None, job_name=None):
logging.info("make_env")
env = make_env(cfg, transform=offline_buffer.transform)
env = SerialEnv(
cfg.rollout_batch_size,
create_env_fn=make_env,
create_env_kwargs=[
{"cfg": cfg, "seed": s, "transform": offline_buffer.transform}
for s in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
],
)
logging.info("make_policy")
policy = make_policy(cfg)
@@ -191,7 +200,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
env,
td_policy,
num_episodes=cfg.eval_episodes,
max_steps=cfg.env.episode_length // cfg.n_action_steps,
max_steps=cfg.env.episode_length,
return_first_video=True,
video_dir=Path(out_dir) / "eval",
save_video=True,