Add Normalize, non_blocking=True in tdmpc, tdmpc run (TODO: diffusion)

This commit is contained in:
Remi Cadene
2024-03-02 15:53:29 +00:00
parent b5a2f460ea
commit 1ae6205269
9 changed files with 183 additions and 67 deletions

View File

@@ -117,21 +117,10 @@ def train(cfg: dict, out_dir=None, job_name=None):
assert torch.cuda.is_available()
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
set_seed(cfg.seed)
logging.info(colored("Work dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
logging.info("make_env")
env = make_env(cfg)
logging.info("make_policy")
policy = make_policy(cfg)
td_policy = TensorDictModule(
policy,
in_keys=["observation", "step_count"],
out_keys=["action"],
)
logging.info("make_offline_buffer")
offline_buffer = make_offline_buffer(cfg)
@@ -151,8 +140,22 @@ def train(cfg: dict, out_dir=None, job_name=None):
online_buffer = TensorDictReplayBuffer(
storage=LazyMemmapStorage(100_000),
sampler=online_sampler,
transform=offline_buffer._transform,
)
logging.info("make_env")
env = make_env(cfg, transform=offline_buffer._transform)
logging.info("make_policy")
policy = make_policy(cfg, transform=offline_buffer._transform)
td_policy = TensorDictModule(
policy,
in_keys=["observation", "step_count"],
out_keys=["action"],
)
# log metrics to terminal and wandb
logger = Logger(out_dir, job_name, cfg)
step = 0 # number of policy update