Merge pull request #6 from Cadene/user/rcadene/2024_03_04_diffusion

Make diffusion work
Remi, 2024-03-04 18:30:40 +01:00, committed by GitHub
12 changed files with 276 additions and 142 deletions


@@ -118,7 +118,7 @@ def eval(cfg: dict, out_dir=None):
     offline_buffer = make_offline_buffer(cfg)
 
     logging.info("make_env")
-    env = make_env(cfg, transform=offline_buffer.transform)
+    env = make_env(cfg, transform=offline_buffer._transform)
 
     if cfg.policy.pretrained_model_path:
         policy = make_policy(cfg)
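
The one-line fix above swaps the public .transform lookup for the private ._transform attribute on the buffer. A plausible reading, sketched with a hypothetical stand-in class (the real buffer comes from make_offline_buffer and its class is not shown in this diff):

    # Hypothetical stand-in: the buffer keeps its transform on a private
    # attribute and defines no public `transform` alias.
    class OfflineBuffer:
        def __init__(self, transform):
            self._transform = transform

    buf = OfflineBuffer(transform=lambda batch: batch)
    # buf.transform                 # would raise AttributeError
    env_transform = buf._transform  # what the fixed call site reads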
@@ -137,7 +137,7 @@ def eval(cfg: dict, out_dir=None):
         save_video=True,
         video_dir=Path(out_dir) / "eval",
         fps=cfg.env.fps,
-        max_steps=cfg.env.episode_length,
+        max_steps=cfg.env.episode_length // cfg.n_action_steps,
         num_episodes=cfg.eval_episodes,
     )
     print(metrics)
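
This is the recurring change of the commit: max_steps now counts policy calls, not simulator steps. The diffusion policy emits a chunk of n_action_steps actions per call, so an episode of episode_length environment steps needs only episode_length // n_action_steps calls. A small sketch of the arithmetic, with illustrative values (the real ones come from cfg):

    # Illustrative values; the real ones come from cfg.env and cfg.
    episode_length = 300  # environment steps allowed per episode
    n_action_steps = 8    # actions the policy returns per call

    # Each rollout iteration consumes n_action_steps env steps, so the
    # iteration cap is the floor division used throughout this commit.
    max_steps = episode_length // n_action_steps  # 37
    assert max_steps * n_action_steps <= episode_length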


@@ -123,7 +123,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
     set_seed(cfg.seed)
-    logging.info(colored("Work dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
 
     logging.info("make_offline_buffer")
     offline_buffer = make_offline_buffer(cfg)
@@ -153,6 +152,9 @@ def train(cfg: dict, out_dir=None, job_name=None):
     logging.info("make_policy")
     policy = make_policy(cfg)
 
+    num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
+    num_total_params = sum(p.numel() for p in policy.parameters())
+
     td_policy = TensorDictModule(
         policy,
         in_keys=["observation", "step_count"],
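
The hunk is cut off before the module's out_keys. For context, TensorDictModule wraps a plain nn.Module so it reads its inputs from, and writes its outputs to, a TensorDict by key. A toy sketch with an assumed single "observation" in_key and a linear stand-in policy (not the repo's real policy or keys):

    import torch
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule

    toy_policy = torch.nn.Linear(4, 2)
    td_policy = TensorDictModule(toy_policy, in_keys=["observation"], out_keys=["action"])

    td = TensorDict({"observation": torch.randn(1, 4)}, batch_size=[1])
    td = td_policy(td)  # td["action"] now holds the output, shape (1, 2)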
@@ -162,6 +164,16 @@ def train(cfg: dict, out_dir=None, job_name=None):
     # log metrics to terminal and wandb
     logger = Logger(out_dir, job_name, cfg)
 
+    logging.info(colored("Work dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
+    logging.info(f"{cfg.env.task=}")
+    logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})")
+    logging.info(f"{cfg.online_steps=}")
+    logging.info(f"{cfg.env.action_repeat=}")
+    logging.info(f"{offline_buffer.num_samples=} ({format_big_number(offline_buffer.num_samples)})")
+    logging.info(f"{offline_buffer.num_episodes=}")
+    logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
+    logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
+
     step = 0  # number of policy updates
     is_offline = True
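
The new startup block leans on Python 3.8+ self-documenting f-strings: f"{cfg.online_steps=}" logs both the expression and its value, e.g. cfg.online_steps=25000 (value illustrative). format_big_number is a repo helper this diff does not show; a hypothetical sketch of the behavior its call sites suggest:

    # Hypothetical reimplementation of format_big_number, assuming it
    # abbreviates large counts with K/M/B suffixes; the real helper may differ.
    def format_big_number(num: float) -> str:
        for suffix in ["", "K", "M", "B"]:
            if abs(num) < 1000:
                return f"{num:.2f}{suffix}" if suffix else str(num)
            num /= 1000
        return f"{num:.2f}T"

    num_total_params = 262_709_126  # illustrative count
    print(f"{num_total_params=} ({format_big_number(num_total_params)})")
    # num_total_params=262709126 (262.71M)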
@@ -174,19 +186,23 @@ def train(cfg: dict, out_dir=None, job_name=None):
         log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline)
 
         if step > 0 and step % cfg.eval_freq == 0:
             logging.info(f"Eval policy at step {step}")
             eval_info, first_video = eval_policy(
                 env,
                 td_policy,
                 num_episodes=cfg.eval_episodes,
+                max_steps=cfg.env.episode_length // cfg.n_action_steps,
                 return_first_video=True,
             )
             log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline)
             if cfg.wandb.enable:
                 logger.log_video(first_video, step, mode="eval")
+            logging.info("Resume training")
 
         if step > 0 and cfg.save_model and step % cfg.save_freq == 0:
-            logging.info(f"Checkpoint model at step {step}")
+            logging.info(f"Checkpoint policy at step {step}")
             logger.save_model(policy, identifier=step)
+            logging.info("Resume training")
 
         step += 1
@@ -200,11 +216,11 @@ def train(cfg: dict, out_dir=None, job_name=None):
             # TODO: add configurable number of rollouts? (default=1)
             with torch.no_grad():
                 rollout = env.rollout(
-                    max_steps=cfg.env.episode_length,
+                    max_steps=cfg.env.episode_length // cfg.n_action_steps,
                     policy=td_policy,
                     auto_cast_to_device=True,
                 )
-            assert len(rollout) <= cfg.env.episode_length
+            assert len(rollout) <= cfg.env.episode_length // cfg.n_action_steps
             # set same episode index for all time steps contained in this rollout
             rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int)
             online_buffer.extend(rollout)
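
The tightened assertion restates the chunking invariant (at most one rollout entry per policy call), and the line after it stamps one shared episode index across every time step of the batch. A toy check of that tagging logic (values assumed):

    import torch

    env_step = 3      # index of the online episode being collected
    rollout_len = 37  # at most episode_length // n_action_steps entries

    # Matches rollout["episode"] = torch.tensor([env_step] * len(rollout), ...)
    episode_ids = torch.tensor([env_step] * rollout_len, dtype=torch.int)
    assert episode_ids.shape == (rollout_len,)
    assert episode_ids.unique().item() == env_step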
@@ -231,19 +247,23 @@ def train(cfg: dict, out_dir=None, job_name=None):
             log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline)
 
             if step > 0 and step % cfg.eval_freq == 0:
                 logging.info(f"Eval policy at step {step}")
                 eval_info, first_video = eval_policy(
                     env,
                     td_policy,
                     num_episodes=cfg.eval_episodes,
+                    max_steps=cfg.env.episode_length // cfg.n_action_steps,
                     return_first_video=True,
                 )
                 log_eval_info(logger, eval_info, step, cfg, offline_buffer, is_offline)
                 if cfg.wandb.enable:
                     logger.log_video(first_video, step, mode="eval")
+                logging.info("Resume training")
 
             if step > 0 and cfg.save_model and step % cfg.save_freq == 0:
-                logging.info(f"Checkpoint model at step {step}")
+                logging.info(f"Checkpoint policy at step {step}")
                 logger.save_model(policy, identifier=step)
+                logging.info("Resume training")
 
             step += 1
             online_step += 1