Refactor eval.py (#127)

Alexander Soare
2024-05-03 17:33:16 +01:00
committed by GitHub
parent b7b69fcc3d
commit bccee745c3
12 changed files with 457 additions and 298 deletions


@@ -269,7 +269,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
     offline_dataset = make_dataset(cfg)

     logging.info("make_env")
-    env = make_env(cfg, num_parallel_envs=cfg.eval.n_episodes)
+    eval_env = make_env(cfg)

     logging.info("make_policy")
     policy = make_policy(hydra_cfg=cfg, dataset_stats=offline_dataset.stats)
@@ -337,15 +337,16 @@ def train(cfg: dict, out_dir=None, job_name=None):
         if step % cfg.training.eval_freq == 0:
             logging.info(f"Eval policy at step {step}")
             eval_info = eval_policy(
-                env,
+                eval_env,
                 policy,
                 cfg.eval.n_episodes,
                 video_dir=Path(out_dir) / "eval",
                 max_episodes_rendered=4,
-                seed=cfg.seed,
+                start_seed=cfg.seed,
             )
             log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
             if cfg.wandb.enable:
-                logger.log_video(eval_info["videos"][0], step, mode="eval")
+                logger.log_video(eval_info["video_paths"][0], step, mode="eval")
             logging.info("Resume training")

         if cfg.training.save_model and step % cfg.training.save_freq == 0:
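
The `seed` -> `start_seed` rename above suggests that evaluation now derives one seed per episode from a single base value, so batched rollouts stay reproducible. A minimal sketch of that convention, assuming a simple `start_seed + i` offset (the helper name and offset scheme are illustrative, not taken from eval.py):

def episode_seeds(start_seed: int, n_episodes: int) -> list[int]:
    # Derive one deterministic seed per episode from a single base seed.
    # Illustrative only: the actual offset scheme in eval.py may differ.
    return [start_seed + i for i in range(n_episodes)]

# e.g. episode_seeds(1000, 4) -> [1000, 1001, 1002, 1003]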
@@ -395,7 +396,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
             step += 1

     # create an env dedicated to online episodes collection from policy rollout
-    rollout_env = make_env(cfg, num_parallel_envs=1)
+    online_training_env = make_env(cfg, n_envs=1)

     # create an empty online dataset similar to offline dataset
     online_dataset = deepcopy(offline_dataset)
@@ -427,10 +428,12 @@ def train(cfg: dict, out_dir=None, job_name=None):
         policy.eval()
         with torch.no_grad():
             eval_info = eval_policy(
-                rollout_env,
+                online_training_env,
                 policy,
                 n_episodes=1,
                 return_episode_data=True,
-                seed=cfg.seed,
+                start_seed=cfg.training.online_env_seed,
+                enable_progbar=True,
             )

         add_episodes_inplace(
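
Besides the env rename, the online rollout now seeds from `cfg.training.online_env_seed` rather than the global `cfg.seed`, and gains a progress bar. Keeping the two seed streams apart presumably prevents online data collection from replaying the exact initial conditions already consumed by evaluation; a toy illustration with invented values:

eval_seeds = [1000 + i for i in range(50)]      # start_seed=cfg.seed
online_seeds = [10000 + i for i in range(50)]   # start_seed=cfg.training.online_env_seed
assert not set(eval_seeds) & set(online_seeds)  # disjoint, no replayed episodes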
@@ -461,6 +464,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
             step += 1
             online_step += 1

+    eval_env.close()
+    online_training_env.close()
     logging.info("End of training")
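
The two `close()` calls added at the end release the vectorized envs' resources (worker processes, renderers) once training finishes. A sketch of a slightly more defensive variant using try/finally, which would also clean up if the loop raises; this is an alternative pattern, not what the commit does:

eval_env = make_env(cfg)
online_training_env = make_env(cfg, n_envs=1)
try:
    ...  # offline and online training loops
finally:
    # Tear down env subprocesses even on an exception mid-training.
    eval_env.close()
    online_training_env.close()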