forked from tangger/lerobot
Add Normalize, non_blocking=True in tdmpc, tdmpc run (TODO: diffusion)
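The Normalize and non_blocking=True parts of the commit title are not visible in this excerpt (they presumably sit in the tdmpc policy and dataset code). As a rough, hypothetical sketch of the non_blocking idea, unrelated to the actual lerobot call sites: pinning a CPU batch and copying it with non_blocking=True lets the host queue the transfer and keep going instead of waiting for it.

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical batch, not lerobot code. Pinned (page-locked) CPU memory is what
# allows a .to(..., non_blocking=True) copy to overlap with ongoing GPU work.
observation = torch.randn(256, 3, 84, 84)
if device.type == "cuda":
    observation = observation.pin_memory()

# The host does not wait for the copy; later kernels queued on the same CUDA
# stream still see the completed transfer. On CPU the flag is simply a no-op.
observation = observation.to(device, non_blocking=True)
print(observation.device, observation.shape)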
@@ -1,3 +1,4 @@
+import logging
 import threading
 import time
 from pathlib import Path
@@ -10,6 +11,7 @@ import tqdm
 from tensordict.nn import TensorDictModule
 from termcolor import colored
 from torchrl.envs import EnvBase
+from lerobot.common.datasets.factory import make_offline_buffer
 
 from lerobot.common.envs.factory import make_env
 from lerobot.common.policies.factory import make_policy
@@ -112,7 +114,11 @@ def eval(cfg: dict, out_dir=None):
     set_seed(cfg.seed)
     print(colored("Log dir:", "yellow", attrs=["bold"]), out_dir)
 
-    env = make_env(cfg)
+    logging.info("make_offline_buffer")
+    offline_buffer = make_offline_buffer(cfg)
+
+    logging.info("make_env")
+    env = make_env(cfg, transform=offline_buffer.transform)
 
     if cfg.policy.pretrained_model_path:
         policy = make_policy(cfg)
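In eval(), the offline buffer is now built first so that its transform (presumably including the new Normalize step) can be handed to make_env, keeping observations normalized the same way at evaluation time as during training. A minimal sketch of that sharing pattern, with a hypothetical Normalize module and placeholder statistics rather than the real lerobot transform:

import torch
from torch import nn

class Normalize(nn.Module):
    """Hypothetical stand-in for the normalization step in the buffer transform."""

    def __init__(self, mean: torch.Tensor, std: torch.Tensor, eps: float = 1e-6):
        super().__init__()
        self.register_buffer("mean", mean)
        self.register_buffer("std", std)
        self.eps = eps

    def forward(self, observation: torch.Tensor) -> torch.Tensor:
        return (observation - self.mean) / (self.std + self.eps)

# Statistics would come from the offline dataset; zeros/ones are placeholders.
normalize = Normalize(mean=torch.zeros(4), std=torch.ones(4))

# The same module is applied both to sampled offline batches and to observations
# coming back from the environment, mirroring make_env(cfg, transform=...).
obs_from_env = torch.randn(4)
print(normalize(obs_from_env))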
@@ -117,21 +117,10 @@ def train(cfg: dict, out_dir=None, job_name=None):
 
     assert torch.cuda.is_available()
     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
     set_seed(cfg.seed)
     logging.info(colored("Work dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
-
-    logging.info("make_env")
-    env = make_env(cfg)
-
-    logging.info("make_policy")
-    policy = make_policy(cfg)
-
-    td_policy = TensorDictModule(
-        policy,
-        in_keys=["observation", "step_count"],
-        out_keys=["action"],
-    )
 
     logging.info("make_offline_buffer")
     offline_buffer = make_offline_buffer(cfg)
+
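In train(), building env, policy, and td_policy is deferred until after the offline buffer exists, so they can receive the same transform (see the next hunk). The TensorDictModule wrapper itself only routes dictionary keys: it reads in_keys from a TensorDict, calls the wrapped module, and writes the result under out_keys. A self-contained illustration with a toy policy (TinyPolicy and its shapes are invented for the example):

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torch import nn

class TinyPolicy(nn.Module):
    """Toy policy for illustration only: observation + step counter -> action."""

    def __init__(self):
        super().__init__()
        self.net = nn.Linear(8, 2)

    def forward(self, observation, step_count):
        # step_count is accepted but unused here; it is only threaded through,
        # matching the in_keys used in the diff.
        return self.net(observation)

td_policy = TensorDictModule(
    TinyPolicy(),
    in_keys=["observation", "step_count"],
    out_keys=["action"],
)

td = TensorDict(
    {"observation": torch.randn(5, 8), "step_count": torch.zeros(5, 1)},
    batch_size=[5],
)
td_policy(td)              # reads the in_keys, writes td["action"] in place
print(td["action"].shape)  # torch.Size([5, 2])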
@@ -151,8 +140,22 @@ def train(cfg: dict, out_dir=None, job_name=None):
     online_buffer = TensorDictReplayBuffer(
         storage=LazyMemmapStorage(100_000),
         sampler=online_sampler,
+        transform=offline_buffer._transform,
     )
+
+    logging.info("make_env")
+    env = make_env(cfg, transform=offline_buffer._transform)
+
+    logging.info("make_policy")
+    policy = make_policy(cfg, transform=offline_buffer._transform)
+
+    td_policy = TensorDictModule(
+        policy,
+        in_keys=["observation", "step_count"],
+        out_keys=["action"],
+    )
+
     # log metrics to terminal and wandb
     logger = Logger(out_dir, job_name, cfg)
 
     step = 0 # number of policy update
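The online buffer receives the same transform as the offline data, so freshly collected rollouts are normalized consistently before they reach the policy. A hedged sketch of the torchrl pieces involved, with made-up tensor shapes and the transform argument left out for brevity:

import torch
from tensordict import TensorDict
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer

# Sketch only: a memory-mapped online buffer like the one in the diff. The real
# code additionally passes transform=offline_buffer._transform so that sampled
# batches are normalized the same way as the offline data.
online_buffer = TensorDictReplayBuffer(storage=LazyMemmapStorage(100_000))

rollout = TensorDict(
    {"observation": torch.randn(32, 8), "action": torch.randn(32, 2)},
    batch_size=[32],
)
online_buffer.extend(rollout)

batch = online_buffer.sample(16)
print(batch["observation"].shape)  # torch.Size([16, 8])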