Refactor eval.py (#127)
This commit is contained in:
@@ -269,7 +269,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
 offline_dataset = make_dataset(cfg)
 
 logging.info("make_env")
-env = make_env(cfg, num_parallel_envs=cfg.eval.n_episodes)
+eval_env = make_env(cfg)
 
 logging.info("make_policy")
 policy = make_policy(hydra_cfg=cfg, dataset_stats=offline_dataset.stats)
@@ -337,15 +337,16 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||
if step % cfg.training.eval_freq == 0:
|
||||
logging.info(f"Eval policy at step {step}")
|
||||
eval_info = eval_policy(
|
||||
env,
|
||||
eval_env,
|
||||
policy,
|
||||
cfg.eval.n_episodes,
|
||||
video_dir=Path(out_dir) / "eval",
|
||||
max_episodes_rendered=4,
|
||||
seed=cfg.seed,
|
||||
start_seed=cfg.seed,
|
||||
)
|
||||
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
|
||||
if cfg.wandb.enable:
|
||||
logger.log_video(eval_info["videos"][0], step, mode="eval")
|
||||
logger.log_video(eval_info["video_paths"][0], step, mode="eval")
|
||||
logging.info("Resume training")
|
||||
|
||||
if cfg.training.save_model and step % cfg.training.save_freq == 0:
|
||||
@@ -395,7 +396,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
 step += 1
 
 # create an env dedicated to online episodes collection from policy rollout
-rollout_env = make_env(cfg, num_parallel_envs=1)
+online_training_env = make_env(cfg, n_envs=1)
 
 # create an empty online dataset similar to offline dataset
 online_dataset = deepcopy(offline_dataset)
@@ -427,10 +428,12 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||
policy.eval()
|
||||
with torch.no_grad():
|
||||
eval_info = eval_policy(
|
||||
rollout_env,
|
||||
online_training_env,
|
||||
policy,
|
||||
n_episodes=1,
|
||||
return_episode_data=True,
|
||||
seed=cfg.seed,
|
||||
start_seed=cfg.training.online_env_seed,
|
||||
enable_progbar=True,
|
||||
)
|
||||
|
||||
add_episodes_inplace(
|
||||
@@ -461,6 +464,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
 step += 1
 online_step += 1
 
+eval_env.close()
+online_training_env.close()
 logging.info("End of training")
 
 
|
||||
Reference in New Issue
Block a user