WIP

WIP train.py works, loss going down

WIP eval.py

Fix

WIP (eval running, TODO: verify results reproduced)

Eval works! (testing reproducibility)

WIP

pretrained model pusht reproduces same results as torchrl

pretrained model pusht reproduces same results as torchrl

Remove AbstractPolicy, Move all queues in select_action

WIP test_datasets passed (TODO: re-enable NormalizeTransform)
This commit is contained in:
Cadene
2024-03-31 15:05:25 +00:00
parent 920e0d118b
commit 1cdfbc8b52
17 changed files with 826 additions and 621 deletions

View File

@@ -36,111 +36,196 @@ from datetime import datetime as dt
from pathlib import Path
import einops
import gymnasium as gym
import hydra
import imageio
import numpy as np
import torch
import tqdm
from huggingface_hub import snapshot_download
from tensordict.nn import TensorDictModule
from torchrl.envs import EnvBase
from torchrl.envs.batched_envs import BatchedEnvBase
from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.envs.factory import make_env
from lerobot.common.logger import log_output_dir
from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
from lerobot.common.transforms import apply_inverse_transform
def write_video(video_path, stacked_frames, fps):
imageio.mimsave(video_path, stacked_frames, fps=fps)
def preprocess_observation(observation, transform=None):
# map to expected inputs for the policy
obs = {
"observation.image": torch.from_numpy(observation["pixels"]).float(),
"observation.state": torch.from_numpy(observation["agent_pos"]).float(),
}
# convert to (b c h w) torch format
obs["observation.image"] = einops.rearrange(obs["observation.image"], "b h w c -> b c h w")
# apply same transforms as in training
if transform is not None:
for key in obs:
obs[key] = torch.stack([transform({key: item})[key] for item in obs[key]])
return obs
def postprocess_action(action, transform=None):
action = action.to("cpu")
# action is a batch (num_env,action_dim) instead of an item (action_dim),
# we assume applying inverse transform on a batch works the same
action = apply_inverse_transform({"action": action}, transform)["action"].numpy()
assert (
action.ndim == 2
), "we assume dimensions are respectively the number of parallel envs, action dimensions"
return action
def eval_policy(
env: BatchedEnvBase,
policy: AbstractPolicy,
num_episodes: int = 10,
max_steps: int = 30,
env: gym.vector.VectorEnv,
policy,
save_video: bool = False,
video_dir: Path = None,
# TODO(rcadene): make it possible to overwrite fps? we should use env.fps
fps: int = 15,
return_first_video: bool = False,
transform: callable = None,
):
if policy is not None:
policy.eval()
start = time.time()
sum_rewards = []
max_rewards = []
successes = []
all_successes = []
seeds = []
threads = [] # for video saving threads
episode_counter = 0 # for saving the correct number of videos
num_episodes = len(env.envs)
# TODO(alexander-soare): if num_episodes is not evenly divisible by the batch size, this will do more work than
# needed as I'm currently taking a ceil.
for i in tqdm.tqdm(range(-(-num_episodes // env.batch_size[0]))):
ep_frames = []
ep_frames = []
def maybe_render_frame(env: EnvBase, _):
if save_video or (return_first_video and i == 0): # noqa: B023
ep_frames.append(env.render()) # noqa: B023
def maybe_render_frame(env):
if save_video: # noqa: B023
if return_first_video:
visu = env.envs[0].render()
visu = visu[None, ...] # add batch dim
else:
visu = np.stack([env.render() for env in env.envs])
ep_frames.append(visu) # noqa: B023
# Clear the policy's action queue before the start of a new rollout.
if policy is not None:
policy.clear_action_queue()
for _ in range(num_episodes):
seeds.append("TODO")
if env.is_closed:
env.start() # needed to be able to get the seeds the first time as BatchedEnvs are lazy
seeds.extend(env._next_seed)
if hasattr(policy, "reset"):
policy.reset()
else:
logging.warning(
f"Policy {policy} doesnt have a `reset` method. This find if the policy doesnt rely on an internal state during rollout."
)
# reset the environment
observation, info = env.reset(seed=cfg.seed)
maybe_render_frame(env)
rewards = []
successes = []
dones = []
done = torch.tensor([False for _ in env.envs])
step = 0
do_rollout = True
while do_rollout:
# apply transform to normalize the observations
observation = preprocess_observation(observation, transform)
# send observation to device/gpu
observation = {key: observation[key].to(cfg.device, non_blocking=True) for key in observation}
# get the next action for the environment
with torch.inference_mode():
# TODO(alexander-soare): When `break_when_any_done == False` this rolls out for max_steps even when all
# envs are done the first time. But we only use the first rollout. This is a waste of compute.
rollout = env.rollout(
max_steps=max_steps,
policy=policy,
auto_cast_to_device=True,
callback=maybe_render_frame,
break_when_any_done=env.batch_size[0] == 1,
)
# Figure out where in each rollout sequence the first done condition was encountered (results after
# this won't be included).
# Note: this assumes that the shape of the done key is (batch_size, max_steps, 1).
# Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker.
rollout_steps = rollout["next", "done"].shape[1]
done_indices = torch.argmax(rollout["next", "done"].to(int), axis=1) # (batch_size, rollout_steps)
mask = (torch.arange(rollout_steps) <= done_indices).unsqueeze(-1) # (batch_size, rollout_steps, 1)
batch_sum_reward = einops.reduce((rollout["next", "reward"] * mask), "b n 1 -> b", "sum")
batch_max_reward = einops.reduce((rollout["next", "reward"] * mask), "b n 1 -> b", "max")
batch_success = einops.reduce((rollout["next", "success"] * mask), "b n 1 -> b", "any")
sum_rewards.extend(batch_sum_reward.tolist())
max_rewards.extend(batch_max_reward.tolist())
successes.extend(batch_success.tolist())
action = policy.select_action(observation, step)
if save_video or (return_first_video and i == 0):
batch_stacked_frames = np.stack(ep_frames) # (t, b, *)
batch_stacked_frames = batch_stacked_frames.transpose(
1, 0, *range(2, batch_stacked_frames.ndim)
) # (b, t, *)
# apply inverse transform to unnormalize the action
action = postprocess_action(action, transform)
if save_video:
for stacked_frames, done_index in zip(
batch_stacked_frames, done_indices.flatten().tolist(), strict=False
):
if episode_counter >= num_episodes:
continue
video_dir.mkdir(parents=True, exist_ok=True)
video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
thread = threading.Thread(
target=write_video,
args=(str(video_path), stacked_frames[:done_index], fps),
)
thread.start()
threads.append(thread)
episode_counter += 1
# apply the next
observation, reward, terminated, truncated, info = env.step(action)
maybe_render_frame(env)
if return_first_video and i == 0:
first_video = batch_stacked_frames[0].transpose(0, 3, 1, 2)
# TODO(rcadene): implement a wrapper over env to return torch tensors in float32 (and cuda?)
reward = torch.from_numpy(reward)
terminated = torch.from_numpy(terminated)
truncated = torch.from_numpy(truncated)
# environment is considered done (no more steps), when success state is reached (terminated is True),
# or time limit is reached (truncated is True), or it was previsouly done.
done = terminated | truncated | done
if "final_info" in info:
# VectorEnv stores is_success into `info["final_info"][env_id]["is_success"]` instead of `info["is_success"]`
success = [
env_info["is_success"] if env_info is not None else False for env_info in info["final_info"]
]
else:
success = [False for _ in env.envs]
success = torch.tensor(success)
rewards.append(reward)
dones.append(done)
successes.append(success)
step += 1
if done.all():
do_rollout = False
break
rewards = torch.stack(rewards, dim=1)
successes = torch.stack(successes, dim=1)
dones = torch.stack(dones, dim=1)
# Figure out where in each rollout sequence the first done condition was encountered (results after
# this won't be included).
# Note: this assumes that the shape of the done key is (batch_size, max_steps).
# Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker.
done_indices = torch.argmax(dones.to(int), axis=1) # (batch_size, rollout_steps)
expand_done_indices = done_indices[:, None].expand(-1, step)
expand_step_indices = torch.arange(step)[None, :].expand(num_episodes, -1)
mask = (expand_step_indices <= expand_done_indices).int() # (batch_size, rollout_steps)
batch_sum_reward = einops.reduce((rewards * mask), "b n -> b", "sum")
batch_max_reward = einops.reduce((rewards * mask), "b n -> b", "max")
batch_success = einops.reduce((successes * mask), "b n -> b", "any")
sum_rewards.extend(batch_sum_reward.tolist())
max_rewards.extend(batch_max_reward.tolist())
all_successes.extend(batch_success.tolist())
env.close()
if save_video or return_first_video:
batch_stacked_frames = np.stack(ep_frames, 1) # (b, t, *)
if save_video:
for stacked_frames, done_index in zip(
batch_stacked_frames, done_indices.flatten().tolist(), strict=False
):
if episode_counter >= num_episodes:
continue
video_dir.mkdir(parents=True, exist_ok=True)
video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
thread = threading.Thread(
target=write_video,
args=(str(video_path), stacked_frames[:done_index], fps),
)
thread.start()
threads.append(thread)
episode_counter += 1
if return_first_video:
first_video = batch_stacked_frames[0].transpose(0, 3, 1, 2)
for thread in threads:
thread.join()
@@ -158,16 +243,16 @@ def eval_policy(
zip(
sum_rewards[:num_episodes],
max_rewards[:num_episodes],
successes[:num_episodes],
all_successes[:num_episodes],
seeds[:num_episodes],
strict=True,
)
)
],
"aggregated": {
"avg_sum_reward": np.nanmean(sum_rewards[:num_episodes]),
"avg_max_reward": np.nanmean(max_rewards[:num_episodes]),
"pc_success": np.nanmean(successes[:num_episodes]) * 100,
"avg_sum_reward": float(np.nanmean(sum_rewards[:num_episodes])),
"avg_max_reward": float(np.nanmean(max_rewards[:num_episodes])),
"pc_success": float(np.nanmean(all_successes[:num_episodes]) * 100),
"eval_s": time.time() - start,
"eval_ep_s": (time.time() - start) / num_episodes,
},
@@ -194,21 +279,13 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
logging.info("Making transforms.")
# TODO(alexander-soare): Completely decouple datasets from evaluation.
offline_buffer = make_offline_buffer(cfg, stats_path=stats_path)
dataset = make_dataset(cfg, stats_path=stats_path)
logging.info("Making environment.")
env = make_env(cfg, transform=offline_buffer.transform)
env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
if cfg.policy.pretrained_model_path:
policy = make_policy(cfg)
policy = TensorDictModule(
policy,
in_keys=["observation", "step_count"],
out_keys=["action"],
)
else:
# when policy is None, rollout a random policy
policy = None
# when policy is None, rollout a random policy
policy = make_policy(cfg) if cfg.policy.pretrained_model_path else None
info = eval_policy(
env,
@@ -216,8 +293,8 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
save_video=True,
video_dir=Path(out_dir) / "eval",
fps=cfg.env.fps,
max_steps=cfg.env.episode_length,
num_episodes=cfg.eval_episodes,
# TODO(rcadene): what should we do with the transform?
transform=dataset.transform,
)
print(info["aggregated"])

View File

@@ -1,14 +1,12 @@
import logging
from itertools import cycle
from pathlib import Path
import hydra
import numpy as np
import torch
from tensordict.nn import TensorDictModule
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers import PrioritizedSliceSampler
from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.envs.factory import make_env
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy
@@ -34,7 +32,7 @@ def train_notebook(out_dir=None, job_name=None, config_name="default", config_pa
train(cfg, out_dir=out_dir, job_name=job_name)
def log_train_info(logger, info, step, cfg, offline_buffer, is_offline):
def log_train_info(logger, info, step, cfg, dataset, is_offline):
loss = info["loss"]
grad_norm = info["grad_norm"]
lr = info["lr"]
@@ -44,9 +42,9 @@ def log_train_info(logger, info, step, cfg, offline_buffer, is_offline):
# A sample is an (observation,action) pair, where observation and action
# can be on multiple timestamps. In a batch, we have `batch_size`` number of samples.
num_samples = (step + 1) * cfg.policy.batch_size
avg_samples_per_ep = offline_buffer.num_samples / offline_buffer.num_episodes
avg_samples_per_ep = dataset.num_samples / dataset.num_episodes
num_episodes = num_samples / avg_samples_per_ep
num_epochs = num_samples / offline_buffer.num_samples
num_epochs = num_samples / dataset.num_samples
log_items = [
f"step:{format_big_number(step)}",
# number of samples seen during training
@@ -73,7 +71,7 @@ def log_train_info(logger, info, step, cfg, offline_buffer, is_offline):
logger.log_dict(info, step, mode="train")
def log_eval_info(logger, info, step, cfg, offline_buffer, is_offline):
def log_eval_info(logger, info, step, cfg, dataset, is_offline):
eval_s = info["eval_s"]
avg_sum_reward = info["avg_sum_reward"]
pc_success = info["pc_success"]
@@ -81,9 +79,9 @@ def log_eval_info(logger, info, step, cfg, offline_buffer, is_offline):
# A sample is an (observation,action) pair, where observation and action
# can be on multiple timestamps. In a batch, we have `batch_size`` number of samples.
num_samples = (step + 1) * cfg.policy.batch_size
avg_samples_per_ep = offline_buffer.num_samples / offline_buffer.num_episodes
avg_samples_per_ep = dataset.num_samples / dataset.num_episodes
num_episodes = num_samples / avg_samples_per_ep
num_epochs = num_samples / offline_buffer.num_samples
num_epochs = num_samples / dataset.num_samples
log_items = [
f"step:{format_big_number(step)}",
# number of samples seen during training
@@ -124,30 +122,30 @@ def train(cfg: dict, out_dir=None, job_name=None):
torch.backends.cuda.matmul.allow_tf32 = True
set_global_seed(cfg.seed)
logging.info("make_offline_buffer")
offline_buffer = make_offline_buffer(cfg)
logging.info("make_dataset")
dataset = make_dataset(cfg)
# TODO(rcadene): move balanced_sampling, per_alpha, per_beta outside policy
if cfg.policy.balanced_sampling:
logging.info("make online_buffer")
num_traj_per_batch = cfg.policy.batch_size
# if cfg.policy.balanced_sampling:
# logging.info("make online_buffer")
# num_traj_per_batch = cfg.policy.batch_size
online_sampler = PrioritizedSliceSampler(
max_capacity=100_000,
alpha=cfg.policy.per_alpha,
beta=cfg.policy.per_beta,
num_slices=num_traj_per_batch,
strict_length=True,
)
# online_sampler = PrioritizedSliceSampler(
# max_capacity=100_000,
# alpha=cfg.policy.per_alpha,
# beta=cfg.policy.per_beta,
# num_slices=num_traj_per_batch,
# strict_length=True,
# )
online_buffer = TensorDictReplayBuffer(
storage=LazyMemmapStorage(100_000),
sampler=online_sampler,
transform=offline_buffer.transform,
)
# online_buffer = TensorDictReplayBuffer(
# storage=LazyMemmapStorage(100_000),
# sampler=online_sampler,
# transform=dataset.transform,
# )
logging.info("make_env")
env = make_env(cfg, transform=offline_buffer.transform)
env = make_env(cfg)
logging.info("make_policy")
policy = make_policy(cfg)
@@ -155,8 +153,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())
td_policy = TensorDictModule(policy, in_keys=["observation", "step_count"], out_keys=["action"])
# log metrics to terminal and wandb
logger = Logger(out_dir, job_name, cfg)
@@ -165,8 +161,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})")
logging.info(f"{cfg.online_steps=}")
logging.info(f"{cfg.env.action_repeat=}")
logging.info(f"{offline_buffer.num_samples=} ({format_big_number(offline_buffer.num_samples)})")
logging.info(f"{offline_buffer.num_episodes=}")
logging.info(f"{dataset.num_samples=} ({format_big_number(dataset.num_samples)})")
logging.info(f"{dataset.num_episodes=}")
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
@@ -176,14 +172,15 @@ def train(cfg: dict, out_dir=None, job_name=None):
logging.info(f"Eval policy at step {step}")
eval_info, first_video = eval_policy(
env,
td_policy,
policy,
num_episodes=cfg.eval_episodes,
max_steps=cfg.env.episode_length,
return_first_video=True,
video_dir=Path(out_dir) / "eval",
save_video=True,
transform=dataset.transform,
)
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_buffer, is_offline)
log_eval_info(logger, eval_info["aggregated"], step, cfg, dataset, is_offline)
if cfg.wandb.enable:
logger.log_video(first_video, step, mode="eval")
logging.info("Resume training")
@@ -196,14 +193,29 @@ def train(cfg: dict, out_dir=None, job_name=None):
step = 0 # number of policy update (forward + backward + optim)
is_offline = True
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=4,
batch_size=cfg.policy.batch_size,
shuffle=True,
pin_memory=cfg.device != "cpu",
drop_last=True,
)
dl_iter = cycle(dataloader)
for offline_step in range(cfg.offline_steps):
if offline_step == 0:
logging.info("Start offline training on a fixed dataset")
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
policy.train()
train_info = policy.update(offline_buffer, step)
batch = next(dl_iter)
for key in batch:
batch[key] = batch[key].to(cfg.device, non_blocking=True)
train_info = policy.update(batch, step)
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
if step % cfg.log_freq == 0:
log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline)
log_train_info(logger, train_info, step, cfg, dataset, is_offline)
# Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass in
# step + 1.
@@ -211,7 +223,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
step += 1
demo_buffer = offline_buffer if cfg.policy.balanced_sampling else None
demo_buffer = dataset if cfg.policy.balanced_sampling else None
online_step = 0
is_offline = False
for env_step in range(cfg.online_steps):
@@ -221,7 +233,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
with torch.no_grad():
rollout = env.rollout(
max_steps=cfg.env.episode_length,
policy=td_policy,
policy=policy,
auto_cast_to_device=True,
)
@@ -242,7 +254,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
# set same episode index for all time steps contained in this rollout
rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int)
online_buffer.extend(rollout)
# online_buffer.extend(rollout)
ep_sum_reward = rollout["next", "reward"].sum()
ep_max_reward = rollout["next", "reward"].max()
@@ -257,13 +269,13 @@ def train(cfg: dict, out_dir=None, job_name=None):
for _ in range(cfg.policy.utd):
train_info = policy.update(
online_buffer,
# online_buffer,
step,
demo_buffer=demo_buffer,
)
if step % cfg.log_freq == 0:
train_info.update(rollout_info)
log_train_info(logger, train_info, step, cfg, offline_buffer, is_offline)
log_train_info(logger, train_info, step, cfg, dataset, is_offline)
# Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass
# in step + 1.

View File

@@ -10,7 +10,7 @@ from torchrl.data.replay_buffers import (
SamplerWithoutReplacement,
)
from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.logger import log_output_dir
from lerobot.common.utils import init_logging
@@ -44,8 +44,8 @@ def visualize_dataset(cfg: dict, out_dir=None):
shuffle=False,
)
logging.info("make_offline_buffer")
offline_buffer = make_offline_buffer(
logging.info("make_dataset")
dataset = make_dataset(
cfg,
overwrite_sampler=sampler,
# remove all transformations such as rescale images from [0,255] to [0,1] or normalization
@@ -55,12 +55,12 @@ def visualize_dataset(cfg: dict, out_dir=None):
)
logging.info("Start rendering episodes from offline buffer")
video_paths = render_dataset(offline_buffer, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER, cfg.fps)
video_paths = render_dataset(dataset, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER, cfg.fps)
for video_path in video_paths:
logging.info(video_path)
def render_dataset(offline_buffer, out_dir, max_num_samples, fps):
def render_dataset(dataset, out_dir, max_num_samples, fps):
out_dir = Path(out_dir)
video_paths = []
threads = []
@@ -69,17 +69,17 @@ def render_dataset(offline_buffer, out_dir, max_num_samples, fps):
logging.info(f"Visualizing episode {current_ep_idx}")
for i in range(max_num_samples):
# TODO(rcadene): make it work with bsize > 1
ep_td = offline_buffer.sample(1)
ep_td = dataset.sample(1)
ep_idx = ep_td["episode"][FIRST_FRAME].item()
# TODO(rcadene): modify offline_buffer._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames
num_frames_left = offline_buffer._sampler._sample_list.numel()
# TODO(rcadene): modify dataset._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames
num_frames_left = dataset._sampler._sample_list.numel()
episode_is_done = ep_idx != current_ep_idx
if episode_is_done:
logging.info(f"Rendering episode {current_ep_idx}")
for im_key in offline_buffer.image_keys:
for im_key in dataset.image_keys:
if not episode_is_done and num_frames_left > 0 and i < (max_num_samples - 1):
# when first frame of episode, initialize frames dict
if im_key not in frames:
@@ -93,7 +93,7 @@ def render_dataset(offline_buffer, out_dir, max_num_samples, fps):
frames[im_key].append(ep_td["next"][im_key])
out_dir.mkdir(parents=True, exist_ok=True)
if len(offline_buffer.image_keys) > 1:
if len(dataset.image_keys) > 1:
camera = im_key[-1]
video_path = out_dir / f"episode_{current_ep_idx}_{camera}.mp4"
else: