remove policy is None eval end-to-end tests

This commit is contained in:
Cadene
2024-04-10 15:09:04 +00:00
parent 2186429fa8
commit 8866b22db1
2 changed files with 4 additions and 36 deletions

View File

@@ -57,7 +57,7 @@ def write_video(video_path, stacked_frames, fps):
def eval_policy(
env: gym.vector.VectorEnv,
policy,
policy: torch.nn.Module,
max_episodes_rendered: int = 0,
video_dir: Path = None,
# TODO(rcadene): make it possible to overwrite fps? we should use env.fps
@@ -312,12 +312,12 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
logging.info("Making environment.")
env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
# when policy is None, rollout a random policy
policy = make_policy(cfg) if cfg.policy.pretrained_model_path else None
logging.info("Making policy.")
policy = make_policy(cfg)
info = eval_policy(
env,
policy=policy,
policy,
max_episodes_rendered=10,
video_dir=Path(out_dir) / "eval",
# TODO(rcadene): what should we do with the transform?