Eval reproduced! Train running (but not reproduced)

This commit is contained in:
Cadene
2024-02-10 15:46:24 +00:00
parent 937b2f8cba
commit 228c045674
14 changed files with 787 additions and 118 deletions

View File

@@ -32,23 +32,25 @@ def eval_policy(
ep_frames.append(frame)
tensordict = env.reset()
# render first frame before rollout
rendering_callback(env)
if save_video:
# render first frame before rollout
rendering_callback(env)
rollout = env.rollout(
max_steps=max_steps,
policy=policy,
callback=rendering_callback,
callback=rendering_callback if save_video else None,
auto_reset=False,
tensordict=tensordict,
)
# print(", ".join([f"{x:.3f}" for x in rollout["next", "reward"][:,0].tolist()]))
ep_reward = rollout["next", "reward"].sum()
ep_success = rollout["next", "success"].any()
rewards.append(ep_reward.item())
successes.append(ep_success.item())
if save_video:
video_dir.parent.mkdir(parents=True, exist_ok=True)
video_dir.mkdir(parents=True, exist_ok=True)
# TODO(rcadene): make fps configurable
video_path = video_dir / f"eval_episode_{i}.mp4"
imageio.mimsave(video_path, np.stack(ep_frames), fps=15)
@@ -82,8 +84,8 @@ def eval(cfg: dict):
metrics = eval_policy(
env,
policy=policy,
num_episodes=10,
save_video=True,
num_episodes=20,
save_video=False,
video_dir=Path("tmp/2023_01_29_xarm_lift_final"),
)
print(metrics)