forked from tangger/lerobot
WIP: still needs batch logic for the ACT and TDMPC policies
This commit is contained in:
@@ -7,6 +7,7 @@ import torch
|
||||
from tensordict.nn import TensorDictModule
|
||||
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
|
||||
from torchrl.data.replay_buffers import PrioritizedSliceSampler
|
||||
from torchrl.envs import SerialEnv
|
||||
|
||||
from lerobot.common.datasets.factory import make_offline_buffer
|
||||
from lerobot.common.envs.factory import make_env
|
||||
@@ -148,6 +149,14 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||
|
||||
logging.info("make_env")
|
||||
env = make_env(cfg, transform=offline_buffer.transform)
|
||||
env = SerialEnv(
|
||||
cfg.rollout_batch_size,
|
||||
create_env_fn=make_env,
|
||||
create_env_kwargs=[
|
||||
{"cfg": cfg, "seed": s, "transform": offline_buffer.transform}
|
||||
for s in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
|
||||
],
|
||||
)
|
||||
|
||||
logging.info("make_policy")
|
||||
policy = make_policy(cfg)
|
||||
@@ -191,7 +200,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||
env,
|
||||
td_policy,
|
||||
num_episodes=cfg.eval_episodes,
|
||||
max_steps=cfg.env.episode_length // cfg.n_action_steps,
|
||||
max_steps=cfg.env.episode_length,
|
||||
return_first_video=True,
|
||||
video_dir=Path(out_dir) / "eval",
|
||||
save_video=True,
|
||||
|
||||
Reference in New Issue
Block a user