forked from tangger/lerobot
Refactor train, eval_policy, logger, Add diffusion.yaml (WIP)
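The eval_policy changes below decouple evaluation from experiment tracking. The env_step and wandb parameters go away, along with the in-loop wandb.Video logging; instead, a return_first_video flag lets eval_policy hand the first episode's stacked frames back to the caller, which decides how and where to log them. The rendering callback is now defined only when frames are actually needed (save_video, or return_first_video on episode 0) and is None otherwise, so env.render() is skipped entirely in the common case.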
@@ -26,31 +26,31 @@ def eval_policy(
     save_video: bool = False,
     video_dir: Path = None,
     fps: int = 15,
-    env_step: int = None,
-    wandb=None,
+    return_first_video: bool = False,
 ):
-    if wandb is not None:
-        assert env_step is not None
     sum_rewards = []
     max_rewards = []
     successes = []
     threads = []
     for i in range(num_episodes):
-        ep_frames = []
-
-        def rendering_callback(env, td=None):
-            ep_frames.append(env.render())
-
         tensordict = env.reset()
-        if save_video or wandb:
+
+        ep_frames = []
+        if save_video or (return_first_video and i == 0):
+
+            def rendering_callback(env, td=None):
+                ep_frames.append(env.render())
+
             # render first frame before rollout
             rendering_callback(env)
+        else:
+            rendering_callback = None

         with torch.inference_mode():
             rollout = env.rollout(
                 max_steps=max_steps,
                 policy=policy,
-                callback=rendering_callback if save_video or wandb else None,
+                callback=rendering_callback,
                 auto_reset=False,
                 tensordict=tensordict,
                 auto_cast_to_device=True,
@@ -63,7 +63,7 @@ def eval_policy(
         max_rewards.append(ep_max_reward.item())
         successes.append(ep_success.item())

-        if save_video or wandb:
+        if save_video or (return_first_video and i == 0):
             stacked_frames = np.stack(ep_frames)

             if save_video:
@@ -76,12 +76,8 @@ def eval_policy(
                 thread.start()
                 threads.append(thread)

-            first_episode = i == 0
-            if wandb and first_episode:
-                eval_video = wandb.Video(
-                    stacked_frames.transpose(0, 3, 1, 2), fps=fps, format="mp4"
-                )
-                wandb.log({"eval_video": eval_video}, step=env_step)
+            if return_first_video and i == 0:
+                first_video = stacked_frames.transpose(0, 3, 1, 2)

     for thread in threads:
         thread.join()
@@ -91,6 +87,8 @@ def eval_policy(
     metrics = {
         "avg_sum_reward": np.nanmean(sum_rewards),
         "avg_max_reward": np.nanmean(max_rewards),
         "pc_success": np.nanmean(successes) * 100,
     }
+    if return_first_video:
+        return metrics, first_video
     return metrics
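With logging moved out of eval_policy, the call site owns the wandb interaction. Below is a minimal sketch of what that might look like, assuming eval_policy also takes env, policy, num_episodes, and max_steps as parameters (they are used above but their declarations are not shown in this diff) and that env_step is now the caller's global step counter; the wandb.Video arguments mirror the block removed in the third hunk:

    import wandb

    # Evaluate, asking for the first episode's frames back instead of
    # letting eval_policy talk to wandb directly.
    metrics, first_video = eval_policy(
        env,
        policy,
        num_episodes=10,  # illustrative value, not from this diff
        max_steps=500,    # illustrative value, not from this diff
        return_first_video=True,
    )

    # first_video arrives in (time, channel, height, width) order because
    # eval_policy already applied stacked_frames.transpose(0, 3, 1, 2).
    wandb.log(
        {"eval_video": wandb.Video(first_video, fps=15, format="mp4"), **metrics},
        step=env_step,
    )

This keeps eval_policy logger-agnostic: the same (metrics, first_video) return value can feed wandb, another tracker, or plain disk writes, which fits the logger refactor named in the commit title.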