backup wip

This commit is contained in:
Alexander Soare
2024-03-19 16:02:09 +00:00
parent 88347965c2
commit ea17f4ce50
11 changed files with 71 additions and 46 deletions

View File

@@ -9,7 +9,7 @@ import numpy as np
import torch
import tqdm
from tensordict.nn import TensorDictModule
from torchrl.envs import EnvBase, SerialEnv
from torchrl.envs import EnvBase
from torchrl.envs.batched_envs import BatchedEnvBase
from lerobot.common.datasets.factory import make_offline_buffer
@@ -131,14 +131,7 @@ def eval(cfg: dict, out_dir=None):
offline_buffer = make_offline_buffer(cfg)
logging.info("make_env")
env = SerialEnv(
cfg.rollout_batch_size,
create_env_fn=make_env,
create_env_kwargs=[
{"cfg": cfg, "seed": env_seed, "transform": offline_buffer.transform}
for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
],
)
env = make_env(cfg, transform=offline_buffer.transform)
if cfg.policy.pretrained_model_path:
policy = make_policy(cfg)