Add Normalize, non_blocking=True in tdmpc, tdmpc run (TODO: diffusion)

This commit is contained in:
Remi Cadene
2024-03-02 15:53:29 +00:00
parent b5a2f460ea
commit 1ae6205269
9 changed files with 183 additions and 67 deletions

View File

@@ -1,3 +1,4 @@
import logging
import threading
import time
from pathlib import Path
@@ -10,6 +11,7 @@ import tqdm
from tensordict.nn import TensorDictModule
from termcolor import colored
from torchrl.envs import EnvBase
from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.envs.factory import make_env
from lerobot.common.policies.factory import make_policy
@@ -112,7 +114,11 @@ def eval(cfg: dict, out_dir=None):
set_seed(cfg.seed)
print(colored("Log dir:", "yellow", attrs=["bold"]), out_dir)
env = make_env(cfg)
logging.info("make_offline_buffer")
offline_buffer = make_offline_buffer(cfg)
logging.info("make_env")
env = make_env(cfg, transform=offline_buffer.transform)
if cfg.policy.pretrained_model_path:
policy = make_policy(cfg)