WIP

WIP train.py works, loss going down

WIP eval.py

Fix

WIP (eval running, TODO: verify results reproduced)

Eval works! (testing reproducibility)

WIP

pretrained model pusht reproduces same results as torchrl

Remove AbstractPolicy, Move all queues in select_action

WIP test_datasets passed (TODO: re-enable NormalizeTransform)
Cadene
2024-03-31 15:05:25 +00:00
parent 920e0d118b
commit 1cdfbc8b52
17 changed files with 826 additions and 621 deletions

@@ -36,111 +36,196 @@ from datetime import datetime as dt
 from pathlib import Path
 
 import einops
+import gymnasium as gym
 import hydra
 import imageio
 import numpy as np
 import torch
 import tqdm
 from huggingface_hub import snapshot_download
-from tensordict.nn import TensorDictModule
-from torchrl.envs import EnvBase
-from torchrl.envs.batched_envs import BatchedEnvBase
 
-from lerobot.common.datasets.factory import make_offline_buffer
+from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.envs.factory import make_env
 from lerobot.common.logger import log_output_dir
-from lerobot.common.policies.abstract import AbstractPolicy
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
+from lerobot.common.transforms import apply_inverse_transform
 
 def write_video(video_path, stacked_frames, fps):
     imageio.mimsave(video_path, stacked_frames, fps=fps)
+
+def preprocess_observation(observation, transform=None):
+    # map to expected inputs for the policy
+    obs = {
+        "observation.image": torch.from_numpy(observation["pixels"]).float(),
+        "observation.state": torch.from_numpy(observation["agent_pos"]).float(),
+    }
+    # convert to (b c h w) torch format
+    obs["observation.image"] = einops.rearrange(obs["observation.image"], "b h w c -> b c h w")
+    # apply the same transforms as in training
+    if transform is not None:
+        for key in obs:
+            obs[key] = torch.stack([transform({key: item})[key] for item in obs[key]])
+    return obs
+
+def postprocess_action(action, transform=None):
+    action = action.to("cpu")
+    # action is a batch (num_env, action_dim) instead of an item (action_dim);
+    # we assume applying the inverse transform on a batch works the same way
+    action = apply_inverse_transform({"action": action}, transform)["action"].numpy()
+    assert (
+        action.ndim == 2
+    ), "we assume dimensions are respectively the number of parallel envs, action dimensions"
+    return action
 
 def eval_policy(
-    env: BatchedEnvBase,
-    policy: AbstractPolicy,
-    num_episodes: int = 10,
-    max_steps: int = 30,
+    env: gym.vector.VectorEnv,
+    policy,
     save_video: bool = False,
     video_dir: Path = None,
     # TODO(rcadene): make it possible to overwrite fps? we should use env.fps
     fps: int = 15,
     return_first_video: bool = False,
+    transform: callable = None,
 ):
     if policy is not None:
         policy.eval()
     start = time.time()
     sum_rewards = []
     max_rewards = []
-    successes = []
+    all_successes = []
     seeds = []
     threads = []  # for video saving threads
     episode_counter = 0  # for saving the correct number of videos
+    num_episodes = len(env.envs)
-    # TODO(alexander-soare): if num_episodes is not evenly divisible by the batch size, this will do more work than
-    # needed as I'm currently taking a ceil.
-    for i in tqdm.tqdm(range(-(-num_episodes // env.batch_size[0]))):
-        ep_frames = []
+    ep_frames = []
-        def maybe_render_frame(env: EnvBase, _):
-            if save_video or (return_first_video and i == 0):  # noqa: B023
-                ep_frames.append(env.render())  # noqa: B023
+    def maybe_render_frame(env):
+        if save_video:  # noqa: B023
+            if return_first_video:
+                visu = env.envs[0].render()
+                visu = visu[None, ...]  # add batch dim
+            else:
+                visu = np.stack([env.render() for env in env.envs])
+            ep_frames.append(visu)  # noqa: B023
-        # Clear the policy's action queue before the start of a new rollout.
-        if policy is not None:
-            policy.clear_action_queue()
+    for _ in range(num_episodes):
+        seeds.append("TODO")
-        if env.is_closed:
-            env.start()  # needed to be able to get the seeds the first time as BatchedEnvs are lazy
-        seeds.extend(env._next_seed)
+    if hasattr(policy, "reset"):
+        policy.reset()
+    else:
+        logging.warning(
+            f"Policy {policy} doesn't have a `reset` method. This is fine if the policy does not rely on an internal state during rollout."
+        )
+    # reset the environment
+    observation, info = env.reset(seed=cfg.seed)
+    maybe_render_frame(env)
+    rewards = []
+    successes = []
+    dones = []
+    done = torch.tensor([False for _ in env.envs])
+    step = 0
+    do_rollout = True
+    while do_rollout:
+        # apply transform to normalize the observations
+        observation = preprocess_observation(observation, transform)
+        # send observation to device/gpu
+        observation = {key: observation[key].to(cfg.device, non_blocking=True) for key in observation}
+        # get the next action for the environment
         with torch.inference_mode():
-            # TODO(alexander-soare): When `break_when_any_done == False` this rolls out for max_steps even when all
-            # envs are done the first time. But we only use the first rollout. This is a waste of compute.
-            rollout = env.rollout(
-                max_steps=max_steps,
-                policy=policy,
-                auto_cast_to_device=True,
-                callback=maybe_render_frame,
-                break_when_any_done=env.batch_size[0] == 1,
-            )
-        # Figure out where in each rollout sequence the first done condition was encountered (results after
-        # this won't be included).
-        # Note: this assumes that the shape of the done key is (batch_size, max_steps, 1).
-        # Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker.
-        rollout_steps = rollout["next", "done"].shape[1]
-        done_indices = torch.argmax(rollout["next", "done"].to(int), axis=1)  # (batch_size, rollout_steps)
-        mask = (torch.arange(rollout_steps) <= done_indices).unsqueeze(-1)  # (batch_size, rollout_steps, 1)
-        batch_sum_reward = einops.reduce((rollout["next", "reward"] * mask), "b n 1 -> b", "sum")
-        batch_max_reward = einops.reduce((rollout["next", "reward"] * mask), "b n 1 -> b", "max")
-        batch_success = einops.reduce((rollout["next", "success"] * mask), "b n 1 -> b", "any")
-        sum_rewards.extend(batch_sum_reward.tolist())
-        max_rewards.extend(batch_max_reward.tolist())
-        successes.extend(batch_success.tolist())
+            action = policy.select_action(observation, step)
-        if save_video or (return_first_video and i == 0):
-            batch_stacked_frames = np.stack(ep_frames)  # (t, b, *)
-            batch_stacked_frames = batch_stacked_frames.transpose(
-                1, 0, *range(2, batch_stacked_frames.ndim)
-            )  # (b, t, *)
+        # apply inverse transform to unnormalize the action
+        action = postprocess_action(action, transform)
-            if save_video:
-                for stacked_frames, done_index in zip(
-                    batch_stacked_frames, done_indices.flatten().tolist(), strict=False
-                ):
-                    if episode_counter >= num_episodes:
-                        continue
-                    video_dir.mkdir(parents=True, exist_ok=True)
-                    video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
-                    thread = threading.Thread(
-                        target=write_video,
-                        args=(str(video_path), stacked_frames[:done_index], fps),
-                    )
-                    thread.start()
-                    threads.append(thread)
-                    episode_counter += 1
+        # apply the next action
+        observation, reward, terminated, truncated, info = env.step(action)
+        maybe_render_frame(env)
-            if return_first_video and i == 0:
-                first_video = batch_stacked_frames[0].transpose(0, 3, 1, 2)
+        # TODO(rcadene): implement a wrapper over env to return torch tensors in float32 (and cuda?)
+        reward = torch.from_numpy(reward)
+        terminated = torch.from_numpy(terminated)
+        truncated = torch.from_numpy(truncated)
+        # the environment is considered done (no more steps) when the success state is reached (terminated is True),
+        # when the time limit is reached (truncated is True), or when it was previously done
+        done = terminated | truncated | done
+        if "final_info" in info:
+            # VectorEnv stores is_success in `info["final_info"][env_id]["is_success"]` instead of `info["is_success"]`
+            success = [
+                env_info["is_success"] if env_info is not None else False for env_info in info["final_info"]
+            ]
+        else:
+            success = [False for _ in env.envs]
+        success = torch.tensor(success)
+        rewards.append(reward)
+        dones.append(done)
+        successes.append(success)
+        step += 1
+        if done.all():
+            do_rollout = False
+            break
+    rewards = torch.stack(rewards, dim=1)
+    successes = torch.stack(successes, dim=1)
+    dones = torch.stack(dones, dim=1)
+    # Figure out where in each rollout sequence the first done condition was encountered (results after
+    # this won't be included).
+    # Note: this assumes that the shape of the done key is (batch_size, max_steps).
+    # Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker.
+    done_indices = torch.argmax(dones.to(int), axis=1)  # (batch_size, rollout_steps)
+    expand_done_indices = done_indices[:, None].expand(-1, step)
+    expand_step_indices = torch.arange(step)[None, :].expand(num_episodes, -1)
+    mask = (expand_step_indices <= expand_done_indices).int()  # (batch_size, rollout_steps)
+    batch_sum_reward = einops.reduce((rewards * mask), "b n -> b", "sum")
+    batch_max_reward = einops.reduce((rewards * mask), "b n -> b", "max")
+    batch_success = einops.reduce((successes * mask), "b n -> b", "any")
+    sum_rewards.extend(batch_sum_reward.tolist())
+    max_rewards.extend(batch_max_reward.tolist())
+    all_successes.extend(batch_success.tolist())
+    env.close()
+    if save_video or return_first_video:
+        batch_stacked_frames = np.stack(ep_frames, 1)  # (b, t, *)
+        if save_video:
+            for stacked_frames, done_index in zip(
+                batch_stacked_frames, done_indices.flatten().tolist(), strict=False
+            ):
+                if episode_counter >= num_episodes:
+                    continue
+                video_dir.mkdir(parents=True, exist_ok=True)
+                video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
+                thread = threading.Thread(
+                    target=write_video,
+                    args=(str(video_path), stacked_frames[:done_index], fps),
+                )
+                thread.start()
+                threads.append(thread)
+                episode_counter += 1
+        if return_first_video:
+            first_video = batch_stacked_frames[0].transpose(0, 3, 1, 2)
     for thread in threads:
         thread.join()
@@ -158,16 +243,16 @@ def eval_policy(
                 zip(
                     sum_rewards[:num_episodes],
                     max_rewards[:num_episodes],
-                    successes[:num_episodes],
+                    all_successes[:num_episodes],
                     seeds[:num_episodes],
                     strict=True,
                 )
             )
         ],
         "aggregated": {
-            "avg_sum_reward": np.nanmean(sum_rewards[:num_episodes]),
-            "avg_max_reward": np.nanmean(max_rewards[:num_episodes]),
-            "pc_success": np.nanmean(successes[:num_episodes]) * 100,
+            "avg_sum_reward": float(np.nanmean(sum_rewards[:num_episodes])),
+            "avg_max_reward": float(np.nanmean(max_rewards[:num_episodes])),
+            "pc_success": float(np.nanmean(all_successes[:num_episodes]) * 100),
             "eval_s": time.time() - start,
             "eval_ep_s": (time.time() - start) / num_episodes,
         },
@@ -194,21 +279,13 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
     logging.info("Making transforms.")
     # TODO(alexander-soare): Completely decouple datasets from evaluation.
-    offline_buffer = make_offline_buffer(cfg, stats_path=stats_path)
+    dataset = make_dataset(cfg, stats_path=stats_path)
 
     logging.info("Making environment.")
-    env = make_env(cfg, transform=offline_buffer.transform)
+    env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
 
-    if cfg.policy.pretrained_model_path:
-        policy = make_policy(cfg)
-        policy = TensorDictModule(
-            policy,
-            in_keys=["observation", "step_count"],
-            out_keys=["action"],
-        )
-    else:
-        # when policy is None, rollout a random policy
-        policy = None
+    # when policy is None, rollout a random policy
+    policy = make_policy(cfg) if cfg.policy.pretrained_model_path else None
 
     info = eval_policy(
         env,
@@ -216,8 +293,8 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
         save_video=True,
         video_dir=Path(out_dir) / "eval",
         fps=cfg.env.fps,
-        max_steps=cfg.env.episode_length,
-        num_episodes=cfg.eval_episodes,
+        # TODO(rcadene): what should we do with the transform?
+        transform=dataset.transform,
     )
     print(info["aggregated"])
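
For reference (not part of the diff above): the new reward/success aggregation in eval_policy relies on torch.argmax over the int-cast dones to find, per episode, the index of the first True step (argmax keeps the first occurrence on ties), and on an arange-based comparison to mask out every step after that first done. A minimal standalone sketch of that masking follows; the variable names, tensor shapes, and values are made up for illustration only.

import torch
import einops

# Illustrative rollout: 2 parallel envs, 5 steps (values are made up).
rewards = torch.tensor([[0.1, 0.2, 0.3, 0.0, 0.0],
                        [0.0, 0.1, 0.1, 0.4, 0.9]])
dones = torch.tensor([[False, False, True, True, True],
                      [False, False, False, False, True]])
successes = torch.tensor([[False, False, True, False, False],
                          [False, False, False, False, True]])

num_envs, num_steps = rewards.shape

# argmax over the int-cast dones returns the index of the first True per row
# (argmax keeps the first occurrence when there are ties).
done_indices = torch.argmax(dones.to(int), dim=1)  # tensor([2, 4])

# Build a mask that keeps steps up to and including the first done of each episode.
expand_done_indices = done_indices[:, None].expand(-1, num_steps)
expand_step_indices = torch.arange(num_steps)[None, :].expand(num_envs, -1)
mask = (expand_step_indices <= expand_done_indices).int()  # (num_envs, num_steps)

ep_sum_reward = einops.reduce(rewards * mask, "b n -> b", "sum")   # ~[0.6, 1.5]
ep_max_reward = einops.reduce(rewards * mask, "b n -> b", "max")   # [0.3, 0.9]
ep_success = einops.reduce(successes * mask, "b n -> b", "any")    # [True, True]

In eval_policy above, num_steps corresponds to step, and the three reductions feed sum_rewards, max_rewards, and all_successes.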