WIP
WIP WIP train.py works, loss going down WIP eval.py Fix WIP (eval running, TODO: verify results reproduced) Eval works! (testing reproducibility) WIP pretrained model pusht reproduces same results as torchrl pretrained model pusht reproduces same results as torchrl Remove AbstractPolicy, Move all queues in select_action WIP test_datasets passed (TODO: re-enable NormalizeTransform)
This commit is contained in:
@@ -36,111 +36,196 @@ from datetime import datetime as dt
|
||||
from pathlib import Path
|
||||
|
||||
import einops
|
||||
import gymnasium as gym
|
||||
import hydra
|
||||
import imageio
|
||||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
from huggingface_hub import snapshot_download
|
||||
from tensordict.nn import TensorDictModule
|
||||
from torchrl.envs import EnvBase
|
||||
from torchrl.envs.batched_envs import BatchedEnvBase
|
||||
|
||||
from lerobot.common.datasets.factory import make_offline_buffer
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.logger import log_output_dir
|
||||
from lerobot.common.policies.abstract import AbstractPolicy
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
|
||||
from lerobot.common.transforms import apply_inverse_transform
|
||||
|
||||
|
||||
def write_video(video_path, stacked_frames, fps):
|
||||
imageio.mimsave(video_path, stacked_frames, fps=fps)
|
||||
|
||||
|
||||
def preprocess_observation(observation, transform=None):
|
||||
# map to expected inputs for the policy
|
||||
obs = {
|
||||
"observation.image": torch.from_numpy(observation["pixels"]).float(),
|
||||
"observation.state": torch.from_numpy(observation["agent_pos"]).float(),
|
||||
}
|
||||
# convert to (b c h w) torch format
|
||||
obs["observation.image"] = einops.rearrange(obs["observation.image"], "b h w c -> b c h w")
|
||||
|
||||
# apply same transforms as in training
|
||||
if transform is not None:
|
||||
for key in obs:
|
||||
obs[key] = torch.stack([transform({key: item})[key] for item in obs[key]])
|
||||
|
||||
return obs
|
||||
|
||||
|
||||
def postprocess_action(action, transform=None):
|
||||
action = action.to("cpu")
|
||||
# action is a batch (num_env,action_dim) instead of an item (action_dim),
|
||||
# we assume applying inverse transform on a batch works the same
|
||||
action = apply_inverse_transform({"action": action}, transform)["action"].numpy()
|
||||
assert (
|
||||
action.ndim == 2
|
||||
), "we assume dimensions are respectively the number of parallel envs, action dimensions"
|
||||
return action
|
||||
|
||||
|
||||
def eval_policy(
|
||||
env: BatchedEnvBase,
|
||||
policy: AbstractPolicy,
|
||||
num_episodes: int = 10,
|
||||
max_steps: int = 30,
|
||||
env: gym.vector.VectorEnv,
|
||||
policy,
|
||||
save_video: bool = False,
|
||||
video_dir: Path = None,
|
||||
# TODO(rcadene): make it possible to overwrite fps? we should use env.fps
|
||||
fps: int = 15,
|
||||
return_first_video: bool = False,
|
||||
transform: callable = None,
|
||||
):
|
||||
if policy is not None:
|
||||
policy.eval()
|
||||
start = time.time()
|
||||
sum_rewards = []
|
||||
max_rewards = []
|
||||
successes = []
|
||||
all_successes = []
|
||||
seeds = []
|
||||
threads = [] # for video saving threads
|
||||
episode_counter = 0 # for saving the correct number of videos
|
||||
|
||||
num_episodes = len(env.envs)
|
||||
|
||||
# TODO(alexander-soare): if num_episodes is not evenly divisible by the batch size, this will do more work than
|
||||
# needed as I'm currently taking a ceil.
|
||||
for i in tqdm.tqdm(range(-(-num_episodes // env.batch_size[0]))):
|
||||
ep_frames = []
|
||||
ep_frames = []
|
||||
|
||||
def maybe_render_frame(env: EnvBase, _):
|
||||
if save_video or (return_first_video and i == 0): # noqa: B023
|
||||
ep_frames.append(env.render()) # noqa: B023
|
||||
def maybe_render_frame(env):
|
||||
if save_video: # noqa: B023
|
||||
if return_first_video:
|
||||
visu = env.envs[0].render()
|
||||
visu = visu[None, ...] # add batch dim
|
||||
else:
|
||||
visu = np.stack([env.render() for env in env.envs])
|
||||
ep_frames.append(visu) # noqa: B023
|
||||
|
||||
# Clear the policy's action queue before the start of a new rollout.
|
||||
if policy is not None:
|
||||
policy.clear_action_queue()
|
||||
for _ in range(num_episodes):
|
||||
seeds.append("TODO")
|
||||
|
||||
if env.is_closed:
|
||||
env.start() # needed to be able to get the seeds the first time as BatchedEnvs are lazy
|
||||
seeds.extend(env._next_seed)
|
||||
if hasattr(policy, "reset"):
|
||||
policy.reset()
|
||||
else:
|
||||
logging.warning(
|
||||
f"Policy {policy} doesnt have a `reset` method. This find if the policy doesnt rely on an internal state during rollout."
|
||||
)
|
||||
|
||||
# reset the environment
|
||||
observation, info = env.reset(seed=cfg.seed)
|
||||
maybe_render_frame(env)
|
||||
|
||||
rewards = []
|
||||
successes = []
|
||||
dones = []
|
||||
|
||||
done = torch.tensor([False for _ in env.envs])
|
||||
step = 0
|
||||
do_rollout = True
|
||||
while do_rollout:
|
||||
# apply transform to normalize the observations
|
||||
observation = preprocess_observation(observation, transform)
|
||||
|
||||
# send observation to device/gpu
|
||||
observation = {key: observation[key].to(cfg.device, non_blocking=True) for key in observation}
|
||||
|
||||
# get the next action for the environment
|
||||
with torch.inference_mode():
|
||||
# TODO(alexander-soare): When `break_when_any_done == False` this rolls out for max_steps even when all
|
||||
# envs are done the first time. But we only use the first rollout. This is a waste of compute.
|
||||
rollout = env.rollout(
|
||||
max_steps=max_steps,
|
||||
policy=policy,
|
||||
auto_cast_to_device=True,
|
||||
callback=maybe_render_frame,
|
||||
break_when_any_done=env.batch_size[0] == 1,
|
||||
)
|
||||
# Figure out where in each rollout sequence the first done condition was encountered (results after
|
||||
# this won't be included).
|
||||
# Note: this assumes that the shape of the done key is (batch_size, max_steps, 1).
|
||||
# Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker.
|
||||
rollout_steps = rollout["next", "done"].shape[1]
|
||||
done_indices = torch.argmax(rollout["next", "done"].to(int), axis=1) # (batch_size, rollout_steps)
|
||||
mask = (torch.arange(rollout_steps) <= done_indices).unsqueeze(-1) # (batch_size, rollout_steps, 1)
|
||||
batch_sum_reward = einops.reduce((rollout["next", "reward"] * mask), "b n 1 -> b", "sum")
|
||||
batch_max_reward = einops.reduce((rollout["next", "reward"] * mask), "b n 1 -> b", "max")
|
||||
batch_success = einops.reduce((rollout["next", "success"] * mask), "b n 1 -> b", "any")
|
||||
sum_rewards.extend(batch_sum_reward.tolist())
|
||||
max_rewards.extend(batch_max_reward.tolist())
|
||||
successes.extend(batch_success.tolist())
|
||||
action = policy.select_action(observation, step)
|
||||
|
||||
if save_video or (return_first_video and i == 0):
|
||||
batch_stacked_frames = np.stack(ep_frames) # (t, b, *)
|
||||
batch_stacked_frames = batch_stacked_frames.transpose(
|
||||
1, 0, *range(2, batch_stacked_frames.ndim)
|
||||
) # (b, t, *)
|
||||
# apply inverse transform to unnormalize the action
|
||||
action = postprocess_action(action, transform)
|
||||
|
||||
if save_video:
|
||||
for stacked_frames, done_index in zip(
|
||||
batch_stacked_frames, done_indices.flatten().tolist(), strict=False
|
||||
):
|
||||
if episode_counter >= num_episodes:
|
||||
continue
|
||||
video_dir.mkdir(parents=True, exist_ok=True)
|
||||
video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
|
||||
thread = threading.Thread(
|
||||
target=write_video,
|
||||
args=(str(video_path), stacked_frames[:done_index], fps),
|
||||
)
|
||||
thread.start()
|
||||
threads.append(thread)
|
||||
episode_counter += 1
|
||||
# apply the next
|
||||
observation, reward, terminated, truncated, info = env.step(action)
|
||||
maybe_render_frame(env)
|
||||
|
||||
if return_first_video and i == 0:
|
||||
first_video = batch_stacked_frames[0].transpose(0, 3, 1, 2)
|
||||
# TODO(rcadene): implement a wrapper over env to return torch tensors in float32 (and cuda?)
|
||||
reward = torch.from_numpy(reward)
|
||||
terminated = torch.from_numpy(terminated)
|
||||
truncated = torch.from_numpy(truncated)
|
||||
# environment is considered done (no more steps), when success state is reached (terminated is True),
|
||||
# or time limit is reached (truncated is True), or it was previsouly done.
|
||||
done = terminated | truncated | done
|
||||
|
||||
if "final_info" in info:
|
||||
# VectorEnv stores is_success into `info["final_info"][env_id]["is_success"]` instead of `info["is_success"]`
|
||||
success = [
|
||||
env_info["is_success"] if env_info is not None else False for env_info in info["final_info"]
|
||||
]
|
||||
else:
|
||||
success = [False for _ in env.envs]
|
||||
success = torch.tensor(success)
|
||||
|
||||
rewards.append(reward)
|
||||
dones.append(done)
|
||||
successes.append(success)
|
||||
|
||||
step += 1
|
||||
|
||||
if done.all():
|
||||
do_rollout = False
|
||||
break
|
||||
|
||||
rewards = torch.stack(rewards, dim=1)
|
||||
successes = torch.stack(successes, dim=1)
|
||||
dones = torch.stack(dones, dim=1)
|
||||
|
||||
# Figure out where in each rollout sequence the first done condition was encountered (results after
|
||||
# this won't be included).
|
||||
# Note: this assumes that the shape of the done key is (batch_size, max_steps).
|
||||
# Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker.
|
||||
done_indices = torch.argmax(dones.to(int), axis=1) # (batch_size, rollout_steps)
|
||||
expand_done_indices = done_indices[:, None].expand(-1, step)
|
||||
expand_step_indices = torch.arange(step)[None, :].expand(num_episodes, -1)
|
||||
mask = (expand_step_indices <= expand_done_indices).int() # (batch_size, rollout_steps)
|
||||
batch_sum_reward = einops.reduce((rewards * mask), "b n -> b", "sum")
|
||||
batch_max_reward = einops.reduce((rewards * mask), "b n -> b", "max")
|
||||
batch_success = einops.reduce((successes * mask), "b n -> b", "any")
|
||||
sum_rewards.extend(batch_sum_reward.tolist())
|
||||
max_rewards.extend(batch_max_reward.tolist())
|
||||
all_successes.extend(batch_success.tolist())
|
||||
|
||||
env.close()
|
||||
|
||||
if save_video or return_first_video:
|
||||
batch_stacked_frames = np.stack(ep_frames, 1) # (b, t, *)
|
||||
|
||||
if save_video:
|
||||
for stacked_frames, done_index in zip(
|
||||
batch_stacked_frames, done_indices.flatten().tolist(), strict=False
|
||||
):
|
||||
if episode_counter >= num_episodes:
|
||||
continue
|
||||
video_dir.mkdir(parents=True, exist_ok=True)
|
||||
video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
|
||||
thread = threading.Thread(
|
||||
target=write_video,
|
||||
args=(str(video_path), stacked_frames[:done_index], fps),
|
||||
)
|
||||
thread.start()
|
||||
threads.append(thread)
|
||||
episode_counter += 1
|
||||
|
||||
if return_first_video:
|
||||
first_video = batch_stacked_frames[0].transpose(0, 3, 1, 2)
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
@@ -158,16 +243,16 @@ def eval_policy(
|
||||
zip(
|
||||
sum_rewards[:num_episodes],
|
||||
max_rewards[:num_episodes],
|
||||
successes[:num_episodes],
|
||||
all_successes[:num_episodes],
|
||||
seeds[:num_episodes],
|
||||
strict=True,
|
||||
)
|
||||
)
|
||||
],
|
||||
"aggregated": {
|
||||
"avg_sum_reward": np.nanmean(sum_rewards[:num_episodes]),
|
||||
"avg_max_reward": np.nanmean(max_rewards[:num_episodes]),
|
||||
"pc_success": np.nanmean(successes[:num_episodes]) * 100,
|
||||
"avg_sum_reward": float(np.nanmean(sum_rewards[:num_episodes])),
|
||||
"avg_max_reward": float(np.nanmean(max_rewards[:num_episodes])),
|
||||
"pc_success": float(np.nanmean(all_successes[:num_episodes]) * 100),
|
||||
"eval_s": time.time() - start,
|
||||
"eval_ep_s": (time.time() - start) / num_episodes,
|
||||
},
|
||||
@@ -194,21 +279,13 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
|
||||
|
||||
logging.info("Making transforms.")
|
||||
# TODO(alexander-soare): Completely decouple datasets from evaluation.
|
||||
offline_buffer = make_offline_buffer(cfg, stats_path=stats_path)
|
||||
dataset = make_dataset(cfg, stats_path=stats_path)
|
||||
|
||||
logging.info("Making environment.")
|
||||
env = make_env(cfg, transform=offline_buffer.transform)
|
||||
env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
|
||||
|
||||
if cfg.policy.pretrained_model_path:
|
||||
policy = make_policy(cfg)
|
||||
policy = TensorDictModule(
|
||||
policy,
|
||||
in_keys=["observation", "step_count"],
|
||||
out_keys=["action"],
|
||||
)
|
||||
else:
|
||||
# when policy is None, rollout a random policy
|
||||
policy = None
|
||||
# when policy is None, rollout a random policy
|
||||
policy = make_policy(cfg) if cfg.policy.pretrained_model_path else None
|
||||
|
||||
info = eval_policy(
|
||||
env,
|
||||
@@ -216,8 +293,8 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
|
||||
save_video=True,
|
||||
video_dir=Path(out_dir) / "eval",
|
||||
fps=cfg.env.fps,
|
||||
max_steps=cfg.env.episode_length,
|
||||
num_episodes=cfg.eval_episodes,
|
||||
# TODO(rcadene): what should we do with the transform?
|
||||
transform=dataset.transform,
|
||||
)
|
||||
print(info["aggregated"])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user