Refactor train, eval_policy, and logger; add diffusion.yaml (WIP)

This commit is contained in:
Cadene
2024-02-26 01:10:09 +00:00
parent 5a219fed6e
commit 21670dce90
12 changed files with 306 additions and 443 deletions

View File

@@ -26,31 +26,31 @@ def eval_policy(
save_video: bool = False,
video_dir: Path = None,
fps: int = 15,
env_step: int = None,
wandb=None,
return_first_video: bool = False,
):
if wandb is not None:
assert env_step is not None
sum_rewards = []
max_rewards = []
successes = []
threads = []
for i in range(num_episodes):
ep_frames = []
def rendering_callback(env, td=None):
ep_frames.append(env.render())
tensordict = env.reset()
if save_video or wandb:
ep_frames = []
if save_video or (return_first_video and i == 0):
def rendering_callback(env, td=None):
ep_frames.append(env.render())
# render first frame before rollout
rendering_callback(env)
else:
rendering_callback = None
with torch.inference_mode():
rollout = env.rollout(
max_steps=max_steps,
policy=policy,
callback=rendering_callback if save_video or wandb else None,
callback=rendering_callback,
auto_reset=False,
tensordict=tensordict,
auto_cast_to_device=True,
@@ -63,7 +63,7 @@ def eval_policy(
max_rewards.append(ep_max_reward.item())
successes.append(ep_success.item())
if save_video or wandb:
if save_video or (return_first_video and i == 0):
stacked_frames = np.stack(ep_frames)
if save_video:
@@ -76,12 +76,8 @@ def eval_policy(
thread.start()
threads.append(thread)
first_episode = i == 0
if wandb and first_episode:
eval_video = wandb.Video(
stacked_frames.transpose(0, 3, 1, 2), fps=fps, format="mp4"
)
wandb.log({"eval_video": eval_video}, step=env_step)
if return_first_video and i == 0:
first_video = stacked_frames.transpose(0, 3, 1, 2)
for thread in threads:
thread.join()
@@ -91,6 +87,8 @@ def eval_policy(
"avg_max_reward": np.nanmean(max_rewards),
"pc_success": np.nanmean(successes) * 100,
}
if return_first_video:
return metrics, first_video
return metrics