wip: still needs batch logic for act and tdmp

Alexander Soare
2024-03-14 15:22:55 +00:00
parent 8c56770318
commit ba91976944
11 changed files with 240 additions and 100 deletions


@@ -9,7 +9,8 @@ import numpy as np
 import torch
 import tqdm
 from tensordict.nn import TensorDictModule
-from torchrl.envs import EnvBase
+from torchrl.envs import EnvBase, SerialEnv
+from torchrl.envs.batched_envs import BatchedEnvBase
 from lerobot.common.datasets.factory import make_offline_buffer
 from lerobot.common.envs.factory import make_env
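Aside (not part of the patch): moving from a single EnvBase to a batched env means env.rollout() now returns a tensordict with a leading env dimension, which is what the per-batch bookkeeping further down relies on. A minimal sketch of that shape convention, assuming torchrl's SerialEnv and GymEnv APIs and using Pendulum-v1 purely as a stand-in environment:

import torch
from torchrl.envs import SerialEnv
from torchrl.envs.libs.gym import GymEnv

# Four serially stepped copies of the same env behind one batched interface.
env = SerialEnv(4, create_env_fn=lambda: GymEnv("Pendulum-v1"))
with torch.inference_mode():
    rollout = env.rollout(max_steps=10)
print(rollout.batch_size)               # torch.Size([4, 10]) -> (env, time)
print(rollout["next", "reward"].shape)  # torch.Size([4, 10, 1])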
@@ -23,7 +24,7 @@ def write_video(video_path, stacked_frames, fps):
 def eval_policy(
-    env: EnvBase,
+    env: BatchedEnvBase,
     policy: TensorDictModule = None,
     num_episodes: int = 10,
     max_steps: int = 30,
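Aside: in the torchrl version this branch targets, BatchedEnvBase is the shared parent of SerialEnv and ParallelEnv, so widening the annotation keeps eval_policy agnostic about whether the sub-envs are stepped serially or in worker processes. Illustrative check only:

from torchrl.envs import ParallelEnv, SerialEnv
from torchrl.envs.batched_envs import BatchedEnvBase

assert issubclass(SerialEnv, BatchedEnvBase)
assert issubclass(ParallelEnv, BatchedEnvBase)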
@@ -36,45 +37,55 @@ def eval_policy(
     sum_rewards = []
     max_rewards = []
     successes = []
-    threads = []
-    for i in tqdm.tqdm(range(num_episodes)):
+    threads = []  # for video saving threads
+    episode_counter = 0  # for saving the correct number of videos
+    # TODO(alexander-soare): if num_episodes is not evenly divisible by the batch size, this will do more work than
+    # needed as I'm currently taking a ceil.
+    for i in tqdm.tqdm(range(-(-num_episodes // env.batch_size[0]))):
         ep_frames = []
-        if save_video or (return_first_video and i == 0):
-            def render_frame(env):
+        def maybe_render_frame(env: EnvBase, _):
+            if save_video or (return_first_video and i == 0):  # noqa: B023
                 ep_frames.append(env.render())  # noqa: B023
-            env.register_rendering_hook(render_frame)
         with torch.inference_mode():
             rollout = env.rollout(
                 max_steps=max_steps,
                 policy=policy,
                 auto_cast_to_device=True,
+                callback=maybe_render_frame,
             )
-        # print(", ".join([f"{x:.3f}" for x in rollout["next", "reward"][:,0].tolist()]))
-        ep_sum_reward = rollout["next", "reward"].sum()
-        ep_max_reward = rollout["next", "reward"].max()
-        ep_success = rollout["next", "success"].any()
-        sum_rewards.append(ep_sum_reward.item())
-        max_rewards.append(ep_max_reward.item())
-        successes.append(ep_success.item())
+        batch_sum_reward = rollout["next", "reward"].flatten(start_dim=1).sum(dim=-1)
+        batch_max_reward = rollout["next", "reward"].flatten(start_dim=1).max(dim=-1)[0]
+        batch_success = rollout["next", "success"].flatten(start_dim=1).any(dim=-1)
+        sum_rewards.extend(batch_sum_reward.tolist())
+        max_rewards.extend(batch_max_reward.tolist())
+        successes.extend(batch_success.tolist())
         if save_video or (return_first_video and i == 0):
-            stacked_frames = np.stack(ep_frames)
+            batch_stacked_frames = np.stack(ep_frames)  # (t, b, *)
+            batch_stacked_frames = batch_stacked_frames.transpose(
+                1, 0, *range(2, batch_stacked_frames.ndim)
+            )  # (b, t, *)
             if save_video:
-                video_dir.mkdir(parents=True, exist_ok=True)
-                video_path = video_dir / f"eval_episode_{i}.mp4"
-                thread = threading.Thread(
-                    target=write_video,
-                    args=(str(video_path), stacked_frames, fps),
-                )
-                thread.start()
-                threads.append(thread)
+                for stacked_frames in batch_stacked_frames:
+                    if episode_counter >= num_episodes:
+                        continue
+                    video_dir.mkdir(parents=True, exist_ok=True)
+                    video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
+                    thread = threading.Thread(
+                        target=write_video,
+                        args=(str(video_path), stacked_frames, fps),
+                    )
+                    thread.start()
+                    threads.append(thread)
+                    episode_counter += 1
             if return_first_video and i == 0:
-                first_video = stacked_frames.transpose(0, 3, 1, 2)
+                first_video = batch_stacked_frames[0].transpose(0, 3, 1, 2)
 
-        env.reset_rendering_hooks()
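Aside: two pieces of the new loop are easy to misread, so here is a hedged sketch of both on dummy tensors (names and numbers are illustrative, not from the repo). The expression -(-a // b) is ceiling division without math.ceil, and flatten(start_dim=1) merges the time and trailing reward dims so each reduction yields one value per environment:

import torch

num_episodes, batch_size = 10, 4
num_rollouts = -(-num_episodes // batch_size)  # ceil(10 / 4) == 3

# Stand-in for rollout["next", "reward"], assumed shaped (env, time, 1).
reward = torch.randn(batch_size, 30, 1)
per_env_sum = reward.flatten(start_dim=1).sum(dim=-1)     # shape (4,)
per_env_max = reward.flatten(start_dim=1).max(dim=-1)[0]  # shape (4,)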
@@ -82,9 +93,9 @@ def eval_policy(
         thread.join()
     info = {
-        "avg_sum_reward": np.nanmean(sum_rewards),
-        "avg_max_reward": np.nanmean(max_rewards),
-        "pc_success": np.nanmean(successes) * 100,
+        "avg_sum_reward": np.nanmean(sum_rewards[:num_episodes]),
+        "avg_max_reward": np.nanmean(max_rewards[:num_episodes]),
+        "pc_success": np.nanmean(successes[:num_episodes]) * 100,
         "eval_s": time.time() - start,
         "eval_ep_s": (time.time() - start) / num_episodes,
     }
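Aside: the [:num_episodes] slices pair with the ceil above. With, say, 10 requested episodes and a batch of 4, three rollouts collect 12 per-episode entries, and without the slice the two surplus episodes would skew the averages (numbers illustrative):

import numpy as np

num_episodes = 10
sum_rewards = list(range(12))  # pretend 3 rollouts of batch size 4 were collected
print(np.nanmean(sum_rewards[:num_episodes]))  # averages only the first 10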
@@ -119,7 +130,14 @@ def eval(cfg: dict, out_dir=None):
     offline_buffer = make_offline_buffer(cfg)
     logging.info("make_env")
-    env = make_env(cfg, transform=offline_buffer.transform)
+    env = SerialEnv(
+        cfg.rollout_batch_size,
+        create_env_fn=make_env,
+        create_env_kwargs=[
+            {"cfg": cfg, "seed": s, "transform": offline_buffer.transform}
+            for s in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
+        ],
+    )
     if cfg.policy.pretrained_model_path:
         policy = make_policy(cfg)
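Aside: create_env_kwargs accepts one kwargs dict per sub-env, which is how each copy gets its own seed here. A hedged sketch of the same pattern with a made-up factory (make_pendulum is illustrative, not the repo's make_env):

from torchrl.envs import SerialEnv
from torchrl.envs.libs.gym import GymEnv

def make_pendulum(seed: int):
    # Each sub-env is built and seeded independently from its own kwargs dict.
    env = GymEnv("Pendulum-v1")
    env.set_seed(seed)
    return env

env = SerialEnv(
    2,
    create_env_fn=make_pendulum,
    create_env_kwargs=[{"seed": s} for s in (0, 1)],
)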
@@ -138,7 +156,7 @@ def eval(cfg: dict, out_dir=None):
         save_video=True,
         video_dir=Path(out_dir) / "eval",
         fps=cfg.env.fps,
-        max_steps=cfg.env.episode_length // cfg.n_action_steps,
+        max_steps=cfg.env.episode_length,
         num_episodes=cfg.eval_episodes,
     )
     print(metrics)