wip: still needs batch logic for act and tdmpc
@@ -9,7 +9,8 @@ import numpy as np
 import torch
 import tqdm
 from tensordict.nn import TensorDictModule
-from torchrl.envs import EnvBase
+from torchrl.envs import EnvBase, SerialEnv
+from torchrl.envs.batched_envs import BatchedEnvBase
 
 from lerobot.common.datasets.factory import make_offline_buffer
 from lerobot.common.envs.factory import make_env
@@ -23,7 +24,7 @@ def eval_policy(
 
 
 def eval_policy(
-    env: EnvBase,
+    env: BatchedEnvBase,
     policy: TensorDictModule = None,
    num_episodes: int = 10,
    max_steps: int = 30,
@@ -36,45 +37,55 @@ def eval_policy(
     sum_rewards = []
     max_rewards = []
     successes = []
-    threads = []
-    for i in tqdm.tqdm(range(num_episodes)):
+    threads = []  # for video saving threads
+    episode_counter = 0  # for saving the correct number of videos
+
+    # TODO(alexander-soare): if num_episodes is not evenly divisible by the batch size, this will do more work than
+    # needed as I'm currently taking a ceil.
+    for i in tqdm.tqdm(range(-(-num_episodes // env.batch_size[0]))):
         ep_frames = []
-        if save_video or (return_first_video and i == 0):
 
-            def render_frame(env):
+        def maybe_render_frame(env: EnvBase, _):
+            if save_video or (return_first_video and i == 0):  # noqa: B023
                 ep_frames.append(env.render())  # noqa: B023
 
-            env.register_rendering_hook(render_frame)
-
         with torch.inference_mode():
             rollout = env.rollout(
                 max_steps=max_steps,
                 policy=policy,
                 auto_cast_to_device=True,
+                callback=maybe_render_frame,
             )
         # print(", ".join([f"{x:.3f}" for x in rollout["next", "reward"][:,0].tolist()]))
-        ep_sum_reward = rollout["next", "reward"].sum()
-        ep_max_reward = rollout["next", "reward"].max()
-        ep_success = rollout["next", "success"].any()
-        sum_rewards.append(ep_sum_reward.item())
-        max_rewards.append(ep_max_reward.item())
-        successes.append(ep_success.item())
+        batch_sum_reward = rollout["next", "reward"].flatten(start_dim=1).sum(dim=-1)
+        batch_max_reward = rollout["next", "reward"].flatten(start_dim=1).max(dim=-1)[0]
+        batch_success = rollout["next", "success"].flatten(start_dim=1).any(dim=-1)
+        sum_rewards.extend(batch_sum_reward.tolist())
+        max_rewards.extend(batch_max_reward.tolist())
+        successes.extend(batch_success.tolist())
 
         if save_video or (return_first_video and i == 0):
-            stacked_frames = np.stack(ep_frames)
+            batch_stacked_frames = np.stack(ep_frames)  # (t, b, *)
+            batch_stacked_frames = batch_stacked_frames.transpose(
+                1, 0, *range(2, batch_stacked_frames.ndim)
+            )  # (b, t, *)
 
             if save_video:
-                video_dir.mkdir(parents=True, exist_ok=True)
-                video_path = video_dir / f"eval_episode_{i}.mp4"
-                thread = threading.Thread(
-                    target=write_video,
-                    args=(str(video_path), stacked_frames, fps),
-                )
-                thread.start()
-                threads.append(thread)
+                for stacked_frames in batch_stacked_frames:
+                    if episode_counter >= num_episodes:
+                        continue
+                    video_dir.mkdir(parents=True, exist_ok=True)
+                    video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
+                    thread = threading.Thread(
+                        target=write_video,
+                        args=(str(video_path), stacked_frames, fps),
+                    )
+                    thread.start()
+                    threads.append(thread)
+                    episode_counter += 1
 
             if return_first_video and i == 0:
-                first_video = stacked_frames.transpose(0, 3, 1, 2)
+                first_video = batch_stacked_frames[0].transpose(0, 3, 1, 2)
 
-        env.reset_rendering_hooks()
-
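Note on the batched rollout logic above: the number of loop iterations is a ceiling division written as -(-num_episodes // batch_size), and the reward/success reductions are taken per environment by flattening every dimension after the batch dimension. A minimal sketch of both ideas with a dummy reward tensor, assuming the (batch, time, 1) layout a batched torchrl rollout would typically produce:

    import torch

    num_episodes, batch_size = 10, 4
    # Ceiling division without math.ceil: -(-10 // 4) == 3 batched rollouts.
    num_batches = -(-num_episodes // batch_size)
    assert num_batches == 3

    # Dummy reward shaped like a batched rollout: (batch, time, 1).
    reward = torch.randn(batch_size, 30, 1)
    sum_per_env = reward.flatten(start_dim=1).sum(dim=-1)     # (batch,)
    max_per_env = reward.flatten(start_dim=1).max(dim=-1)[0]  # (batch,)
    print(sum_per_env.shape, max_per_env.shape)               # torch.Size([4]) torch.Size([4])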
@@ -82,9 +93,9 @@ def eval_policy(
         thread.join()
 
     info = {
-        "avg_sum_reward": np.nanmean(sum_rewards),
-        "avg_max_reward": np.nanmean(max_rewards),
-        "pc_success": np.nanmean(successes) * 100,
+        "avg_sum_reward": np.nanmean(sum_rewards[:num_episodes]),
+        "avg_max_reward": np.nanmean(max_rewards[:num_episodes]),
+        "pc_success": np.nanmean(successes[:num_episodes]) * 100,
         "eval_s": time.time() - start,
         "eval_ep_s": (time.time() - start) / num_episodes,
     }
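The [:num_episodes] slices matter because the ceiling above can collect up to batch_size - 1 extra episodes in the final iteration; trimming keeps the averages defined over exactly the requested number of episodes. Illustrative arithmetic with made-up numbers:

    import numpy as np

    num_episodes, batch_size = 10, 4
    sum_rewards = list(np.arange(12.0))  # 3 batches x 4 envs -> 12 entries collected
    avg_sum_reward = np.nanmean(sum_rewards[:num_episodes])  # only the first 10 are reported
    print(len(sum_rewards), avg_sum_reward)  # 12 4.5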
@@ -119,7 +130,14 @@ def eval(cfg: dict, out_dir=None):
     offline_buffer = make_offline_buffer(cfg)
 
     logging.info("make_env")
-    env = make_env(cfg, transform=offline_buffer.transform)
+    env = SerialEnv(
+        cfg.rollout_batch_size,
+        create_env_fn=make_env,
+        create_env_kwargs=[
+            {"cfg": cfg, "seed": s, "transform": offline_buffer.transform}
+            for s in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
+        ],
+    )
 
     if cfg.policy.pretrained_model_path:
         policy = make_policy(cfg)
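The single make_env call becomes a torchrl SerialEnv holding cfg.rollout_batch_size copies of the environment, each built with its own seed so the batched episodes do not replay identical randomness. A standalone sketch of the same pattern; the GymEnv factory and the Pendulum task below are stand-ins for the project's make_env/cfg, not the real ones:

    from torchrl.envs import SerialEnv
    from torchrl.envs.libs.gym import GymEnv

    def make_env(seed: int = 0):
        # Stand-in factory; the real code builds the task env from the Hydra cfg.
        env = GymEnv("Pendulum-v1")
        env.set_seed(seed)
        return env

    batch_size = 4
    env = SerialEnv(
        batch_size,
        create_env_fn=make_env,
        create_env_kwargs=[{"seed": s} for s in range(batch_size)],
    )
    print(env.batch_size)  # torch.Size([4]); env.batch_size[0] is what eval_policy reads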
@@ -138,7 +156,7 @@ def eval(cfg: dict, out_dir=None):
         save_video=True,
         video_dir=Path(out_dir) / "eval",
         fps=cfg.env.fps,
-        max_steps=cfg.env.episode_length // cfg.n_action_steps,
+        max_steps=cfg.env.episode_length,
         num_episodes=cfg.eval_episodes,
     )
     print(metrics)
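On the max_steps change: the old value divided the episode length by n_action_steps, presumably because each rollout step used to correspond to one policy call emitting a chunk of actions; with the batched env the rollout appears to be budgeted in raw environment steps instead. Rough illustration with hypothetical config values:

    episode_length, n_action_steps = 300, 8  # made-up values, not from the repo's config
    old_max_steps = episode_length // n_action_steps  # 37 rollout steps (one per policy call)
    new_max_steps = episode_length                    # 300 rollout steps (one per env step)
    print(old_max_steps, new_max_steps)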