backup wip
This commit is contained in:
@@ -9,7 +9,7 @@ import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
from tensordict.nn import TensorDictModule
|
||||
from torchrl.envs import EnvBase, SerialEnv
|
||||
from torchrl.envs import EnvBase
|
||||
from torchrl.envs.batched_envs import BatchedEnvBase
|
||||
|
||||
from lerobot.common.datasets.factory import make_offline_buffer
|
||||
@@ -131,14 +131,7 @@ def eval(cfg: dict, out_dir=None):
|
||||
offline_buffer = make_offline_buffer(cfg)
|
||||
|
||||
logging.info("make_env")
|
||||
env = SerialEnv(
|
||||
cfg.rollout_batch_size,
|
||||
create_env_fn=make_env,
|
||||
create_env_kwargs=[
|
||||
{"cfg": cfg, "seed": env_seed, "transform": offline_buffer.transform}
|
||||
for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
|
||||
],
|
||||
)
|
||||
env = make_env(cfg, transform=offline_buffer.transform)
|
||||
|
||||
if cfg.policy.pretrained_model_path:
|
||||
policy = make_policy(cfg)
|
||||
|
||||
Reference in New Issue
Block a user