Merge remote-tracking branch 'upstream/main' into refactor_dp

This commit is contained in:
Alexander Soare
2024-04-11 17:52:10 +01:00
29 changed files with 545 additions and 603 deletions

View File

@@ -32,6 +32,7 @@ import json
import logging
import threading
import time
from copy import deepcopy
from datetime import datetime as dt
from pathlib import Path
@@ -56,15 +57,15 @@ def write_video(video_path, stacked_frames, fps):
def eval_policy(
env: gym.vector.VectorEnv,
policy,
save_video: bool = False,
policy: torch.nn.Module,
max_episodes_rendered: int = 0,
video_dir: Path = None,
# TODO(rcadene): make it possible to overwrite fps? we should use env.fps
fps: int = 15,
return_first_video: bool = False,
transform: callable = None,
seed=None,
):
fps = env.unwrapped.metadata["render_fps"]
if policy is not None:
policy.eval()
device = "cpu" if policy is None else next(policy.parameters()).device
@@ -83,14 +84,11 @@ def eval_policy(
# needed as I'm currently taking a ceil.
ep_frames = []
def maybe_render_frame(env):
if save_video: # noqa: B023
if return_first_video:
visu = env.envs[0].render()
visu = visu[None, ...] # add batch dim
else:
visu = np.stack([env.render() for env in env.envs])
ep_frames.append(visu) # noqa: B023
def render_frame(env):
# noqa: B023
eps_rendered = min(max_episodes_rendered, len(env.envs))
visu = np.stack([env.envs[i].render() for i in range(eps_rendered)])
ep_frames.append(visu) # noqa: B023
for _ in range(num_episodes):
seeds.append("TODO")
@@ -104,8 +102,14 @@ def eval_policy(
# reset the environment
observation, info = env.reset(seed=seed)
maybe_render_frame(env)
if max_episodes_rendered > 0:
render_frame(env)
observations = []
actions = []
# episode
# frame_id
# timestamp
rewards = []
successes = []
dones = []
@@ -113,8 +117,13 @@ def eval_policy(
done = torch.tensor([False for _ in env.envs])
step = 0
while not done.all():
# format from env keys to lerobot keys
observation = preprocess_observation(observation)
observations.append(deepcopy(observation))
# apply transform to normalize the observations
observation = preprocess_observation(observation, transform)
for key in observation:
observation[key] = torch.stack([transform({key: item})[key] for item in observation[key]])
# send observation to device/gpu
observation = {key: observation[key].to(device, non_blocking=True) for key in observation}
@@ -126,11 +135,13 @@ def eval_policy(
# apply inverse transform to unnormalize the action
action = postprocess_action(action, transform)
# apply the next
# apply the next action
observation, reward, terminated, truncated, info = env.step(action)
maybe_render_frame(env)
if max_episodes_rendered > 0:
render_frame(env)
# TODO(rcadene): implement a wrapper over env to return torch tensors in float32 (and cuda?)
action = torch.from_numpy(action)
reward = torch.from_numpy(reward)
terminated = torch.from_numpy(terminated)
truncated = torch.from_numpy(truncated)
@@ -147,12 +158,24 @@ def eval_policy(
success = [False for _ in env.envs]
success = torch.tensor(success)
actions.append(action)
rewards.append(reward)
dones.append(done)
successes.append(success)
step += 1
env.close()
# add the last observation when the env is done
observation = preprocess_observation(observation)
observations.append(deepcopy(observation))
new_obses = {}
for key in observations[0].keys(): # noqa: SIM118
new_obses[key] = torch.stack([obs[key] for obs in observations], dim=1)
observations = new_obses
actions = torch.stack(actions, dim=1)
rewards = torch.stack(rewards, dim=1)
successes = torch.stack(successes, dim=1)
dones = torch.stack(dones, dim=1)
@@ -172,29 +195,61 @@ def eval_policy(
max_rewards.extend(batch_max_reward.tolist())
all_successes.extend(batch_success.tolist())
env.close()
# similar logic is implemented in dataset preprocessing
ep_dicts = []
num_episodes = dones.shape[0]
total_frames = 0
idx0 = idx1 = 0
data_ids_per_episode = {}
for ep_id in range(num_episodes):
num_frames = done_indices[ep_id].item() + 1
# TODO(rcadene): We need to add a missing last frame which is the observation
# of a done state. it is critical to have this frame for tdmpc to predict a "done observation/state"
ep_dict = {
"action": actions[ep_id, :num_frames],
"episode": torch.tensor([ep_id] * num_frames),
"frame_id": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / fps,
"next.done": dones[ep_id, :num_frames],
"next.reward": rewards[ep_id, :num_frames].type(torch.float32),
}
for key in observations:
ep_dict[key] = observations[key][ep_id, :num_frames]
ep_dicts.append(ep_dict)
if save_video or return_first_video:
total_frames += num_frames
idx1 += num_frames
data_ids_per_episode[ep_id] = torch.arange(idx0, idx1, 1)
idx0 = idx1
# similar logic is implemented in dataset preprocessing
data_dict = {}
keys = ep_dicts[0].keys()
for key in keys:
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
data_dict["index"] = torch.arange(0, total_frames, 1)
if max_episodes_rendered > 0:
batch_stacked_frames = np.stack(ep_frames, 1) # (b, t, *)
if save_video:
for stacked_frames, done_index in zip(
batch_stacked_frames, done_indices.flatten().tolist(), strict=False
):
if episode_counter >= num_episodes:
continue
video_dir.mkdir(parents=True, exist_ok=True)
video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
thread = threading.Thread(
target=write_video,
args=(str(video_path), stacked_frames[:done_index], fps),
)
thread.start()
threads.append(thread)
episode_counter += 1
for stacked_frames, done_index in zip(
batch_stacked_frames, done_indices.flatten().tolist(), strict=False
):
if episode_counter >= num_episodes:
continue
video_dir.mkdir(parents=True, exist_ok=True)
video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
thread = threading.Thread(
target=write_video,
args=(str(video_path), stacked_frames[:done_index], fps),
)
thread.start()
threads.append(thread)
episode_counter += 1
if return_first_video:
first_video = batch_stacked_frames[0].transpose(0, 3, 1, 2)
videos = einops.rearrange(batch_stacked_frames, "b t h w c -> b t c h w")
for thread in threads:
thread.join()
@@ -225,9 +280,13 @@ def eval_policy(
"eval_s": time.time() - start,
"eval_ep_s": (time.time() - start) / num_episodes,
},
"episodes": {
"data_dict": data_dict,
"data_ids_per_episode": data_ids_per_episode,
},
}
if return_first_video:
return info, first_video
if max_episodes_rendered > 0:
info["videos"] = videos
return info
@@ -253,16 +312,14 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
logging.info("Making environment.")
env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
# when policy is None, rollout a random policy
policy = make_policy(cfg) if cfg.policy.pretrained_model_path else None
logging.info("Making policy.")
policy = make_policy(cfg)
info = eval_policy(
env,
policy=policy,
save_video=True,
policy,
max_episodes_rendered=10,
video_dir=Path(out_dir) / "eval",
fps=cfg.env.fps,
# TODO(rcadene): what should we do with the transform?
transform=transform,
seed=cfg.seed,
)
@@ -270,6 +327,9 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
# Save info
with open(Path(out_dir) / "eval_info.json", "w") as f:
# remove pytorch tensors which are not serializable to save the evaluation results only
del info["episodes"]
del info["videos"]
json.dump(info, f, indent=2)
logging.info("End of eval")

View File

@@ -1,8 +1,8 @@
import logging
from copy import deepcopy
from pathlib import Path
import hydra
import numpy as np
import torch
from lerobot.common.datasets.factory import make_dataset
@@ -108,6 +108,64 @@ def log_eval_info(logger, info, step, cfg, dataset, is_offline):
logger.log_dict(info, step, mode="eval")
def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float):
"""
Calculate the sampling weight to be assigned to samples so that a specified percentage of the batch comes from online dataset (on average).
Parameters:
- n_off (int): Number of offline samples, each with a sampling weight of 1.
- n_on (int): Number of online samples.
- pc_on (float): Desired percentage of online samples in decimal form (e.g., 50% as 0.5).
The total weight of offline samples is n_off * 1.0.
The total weight of offline samples is n_on * w.
The total combined weight of all samples is n_off + n_on * w.
The fraction of the weight that is online is n_on * w / (n_off + n_on * w).
We want this fraction to equal pc_on, so we set up the equation n_on * w / (n_off + n_on * w) = pc_on.
The solution is w = - (n_off * pc_on) / (n_on * (pc_on - 1))
"""
assert 0.0 <= pc_on <= 1.0
return -(n_off * pc_on) / (n_on * (pc_on - 1))
def add_episodes_inplace(episodes, online_dataset, concat_dataset, sampler, pc_online_samples):
data_dict = episodes["data_dict"]
data_ids_per_episode = episodes["data_ids_per_episode"]
if len(online_dataset) == 0:
# initialize online dataset
online_dataset.data_dict = data_dict
online_dataset.data_ids_per_episode = data_ids_per_episode
else:
# find episode index and data frame indices according to previous episode in online_dataset
start_episode = max(online_dataset.data_ids_per_episode.keys()) + 1
start_index = online_dataset.data_dict["index"][-1].item() + 1
data_dict["episode"] += start_episode
data_dict["index"] += start_index
# extend online dataset
for key in data_dict:
# TODO(rcadene): avoid reallocating memory at every step by preallocating memory or changing our data structure
online_dataset.data_dict[key] = torch.cat([online_dataset.data_dict[key], data_dict[key]])
for ep_id in data_ids_per_episode:
online_dataset.data_ids_per_episode[ep_id + start_episode] = (
data_ids_per_episode[ep_id] + start_index
)
# update the concatenated dataset length used during sampling
concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
# update the sampling weights for each frame so that online frames get sampled a certain percentage of times
len_online = len(online_dataset)
len_offline = len(concat_dataset) - len_online
weight_offline = 1.0
weight_online = calculate_online_sample_weight(len_offline, len_online, pc_online_samples)
sampler.weights = torch.tensor([weight_offline] * len_offline + [weight_online] * len(online_dataset))
# update the total number of samples used during sampling
sampler.num_samples = len(concat_dataset)
def train(cfg: dict, out_dir=None, job_name=None):
if out_dir is None:
raise NotImplementedError()
@@ -126,26 +184,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
set_global_seed(cfg.seed)
logging.info("make_dataset")
dataset = make_dataset(cfg)
# TODO(rcadene): move balanced_sampling, per_alpha, per_beta outside policy
# if cfg.policy.balanced_sampling:
# logging.info("make online_buffer")
# num_traj_per_batch = cfg.policy.batch_size
# online_sampler = PrioritizedSliceSampler(
# max_capacity=100_000,
# alpha=cfg.policy.per_alpha,
# beta=cfg.policy.per_beta,
# num_slices=num_traj_per_batch,
# strict_length=True,
# )
# online_buffer = TensorDictReplayBuffer(
# storage=LazyMemmapStorage(100_000),
# sampler=online_sampler,
# transform=dataset.transform,
# )
offline_dataset = make_dataset(cfg)
logging.info("make_env")
env = make_env(cfg, num_parallel_envs=cfg.eval_episodes)
@@ -163,9 +202,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})")
logging.info(f"{cfg.online_steps=}")
logging.info(f"{cfg.env.action_repeat=}")
logging.info(f"{dataset.num_samples=} ({format_big_number(dataset.num_samples)})")
logging.info(f"{dataset.num_episodes=}")
logging.info(f"{offline_dataset.num_samples=} ({format_big_number(offline_dataset.num_samples)})")
logging.info(f"{offline_dataset.num_episodes=}")
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
@@ -173,18 +211,17 @@ def train(cfg: dict, out_dir=None, job_name=None):
def _maybe_eval_and_maybe_save(step):
if step % cfg.eval_freq == 0:
logging.info(f"Eval policy at step {step}")
eval_info, first_video = eval_policy(
eval_info = eval_policy(
env,
policy,
return_first_video=True,
video_dir=Path(out_dir) / "eval",
save_video=True,
transform=dataset.transform,
max_episodes_rendered=4,
transform=offline_dataset.transform,
seed=cfg.seed,
)
log_eval_info(logger, eval_info["aggregated"], step, cfg, dataset, is_offline)
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
if cfg.wandb.enable:
logger.log_video(first_video, step, mode="eval")
logger.log_video(eval_info["videos"][0], step, mode="eval")
logging.info("Resume training")
if cfg.save_model and step % cfg.save_freq == 0:
@@ -192,18 +229,19 @@ def train(cfg: dict, out_dir=None, job_name=None):
logger.save_model(policy, identifier=step)
logging.info("Resume training")
step = 0 # number of policy update (forward + backward + optim)
is_offline = True
# create dataloader for offline training
dataloader = torch.utils.data.DataLoader(
dataset,
offline_dataset,
num_workers=4,
batch_size=cfg.policy.batch_size,
shuffle=True,
pin_memory=cfg.device != "cpu",
drop_last=True,
drop_last=False,
)
dl_iter = cycle(dataloader)
step = 0 # number of policy update (forward + backward + optim)
is_offline = True
for offline_step in range(cfg.offline_steps):
if offline_step == 0:
logging.info("Start offline training on a fixed dataset")
@@ -217,7 +255,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
if step % cfg.log_freq == 0:
log_train_info(logger, train_info, step, cfg, dataset, is_offline)
log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline)
# Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass in
# step + 1.
@@ -225,61 +263,60 @@ def train(cfg: dict, out_dir=None, job_name=None):
step += 1
raise NotImplementedError()
# create an env dedicated to online episodes collection from policy rollout
rollout_env = make_env(cfg, num_parallel_envs=1)
# create an empty online dataset similar to offline dataset
online_dataset = deepcopy(offline_dataset)
online_dataset.data_dict = {}
online_dataset.data_ids_per_episode = {}
# create dataloader for online training
concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
weights = [1.0] * len(concat_dataset)
sampler = torch.utils.data.WeightedRandomSampler(
weights, num_samples=len(concat_dataset), replacement=True
)
dataloader = torch.utils.data.DataLoader(
concat_dataset,
num_workers=4,
batch_size=cfg.policy.batch_size,
sampler=sampler,
pin_memory=cfg.device != "cpu",
drop_last=False,
)
dl_iter = cycle(dataloader)
demo_buffer = dataset if cfg.policy.balanced_sampling else None
online_step = 0
is_offline = False
for env_step in range(cfg.online_steps):
if env_step == 0:
logging.info("Start online training by interacting with environment")
# TODO: add configurable number of rollout? (default=1)
with torch.no_grad():
rollout = env.rollout(
max_steps=cfg.env.episode_length,
policy=policy,
auto_cast_to_device=True,
eval_info = eval_policy(
rollout_env,
policy,
transform=offline_dataset.transform,
seed=cfg.seed,
)
assert (
len(rollout.batch_size) == 2
), "2 dimensions expected: number of env in parallel x max number of steps during rollout"
num_parallel_env = rollout.batch_size[0]
if num_parallel_env != 1:
# TODO(rcadene): when num_parallel_env > 1, rollout["episode"] needs to be properly set and we need to add tests
raise NotImplementedError()
num_max_steps = rollout.batch_size[1]
assert num_max_steps <= cfg.env.episode_length
# reshape to have a list of steps to insert into online_buffer
rollout = rollout.reshape(num_parallel_env * num_max_steps)
# set same episode index for all time steps contained in this rollout
rollout["episode"] = torch.tensor([env_step] * len(rollout), dtype=torch.int)
# online_buffer.extend(rollout)
ep_sum_reward = rollout["next", "reward"].sum()
ep_max_reward = rollout["next", "reward"].max()
ep_success = rollout["next", "success"].any()
rollout_info = {
"avg_sum_reward": np.nanmean(ep_sum_reward),
"avg_max_reward": np.nanmean(ep_max_reward),
"pc_success": np.nanmean(ep_success) * 100,
"env_step": env_step,
"ep_length": len(rollout),
}
online_pc_sampling = cfg.get("demo_schedule", 0.5)
add_episodes_inplace(
eval_info["episodes"], online_dataset, concat_dataset, sampler, online_pc_sampling
)
for _ in range(cfg.policy.utd):
train_info = policy.update(
# online_buffer,
step,
demo_buffer=demo_buffer,
)
policy.train()
batch = next(dl_iter)
for key in batch:
batch[key] = batch[key].to(cfg.device, non_blocking=True)
train_info = policy(batch, step)
if step % cfg.log_freq == 0:
train_info.update(rollout_info)
log_train_info(logger, train_info, step, cfg, dataset, is_offline)
log_train_info(logger, train_info, step, cfg, online_dataset, is_offline)
# Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass
# in step + 1.

View File

@@ -6,9 +6,6 @@ import einops
import hydra
import imageio
import torch
from torchrl.data.replay_buffers import (
SamplerWithoutReplacement,
)
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.logger import log_output_dir
@@ -39,19 +36,11 @@ def visualize_dataset(cfg: dict, out_dir=None):
init_logging()
log_output_dir(out_dir)
# we expect frames of each episode to be stored next to each others sequentially
sampler = SamplerWithoutReplacement(
shuffle=False,
)
logging.info("make_dataset")
dataset = make_dataset(
cfg,
overwrite_sampler=sampler,
# remove all transformations such as rescale images from [0,255] to [0,1] or normalization
normalize=False,
overwrite_batch_size=1,
overwrite_prefetch=12,
)
logging.info("Start rendering episodes from offline buffer")
@@ -60,64 +49,49 @@ def visualize_dataset(cfg: dict, out_dir=None):
logging.info(video_path)
def render_dataset(dataset, out_dir, max_num_samples, fps):
def render_dataset(dataset, out_dir, max_num_episodes):
out_dir = Path(out_dir)
video_paths = []
threads = []
frames = {}
current_ep_idx = 0
logging.info(f"Visualizing episode {current_ep_idx}")
for i in range(max_num_samples):
# TODO(rcadene): make it work with bsize > 1
ep_td = dataset.sample(1)
ep_idx = ep_td["episode"][FIRST_FRAME].item()
# TODO(rcadene): modify dataset._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames
num_frames_left = dataset._sampler._sample_list.numel()
episode_is_done = ep_idx != current_ep_idx
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=4,
batch_size=1,
shuffle=False,
)
dl_iter = iter(dataloader)
if episode_is_done:
logging.info(f"Rendering episode {current_ep_idx}")
num_episodes = len(dataset.data_ids_per_episode)
for ep_id in range(min(max_num_episodes, num_episodes)):
logging.info(f"Rendering episode {ep_id}")
for im_key in dataset.image_keys:
if not episode_is_done and num_frames_left > 0 and i < (max_num_samples - 1):
frames = {}
for _ in dataset.data_ids_per_episode[ep_id]:
item = next(dl_iter)
for im_key in dataset.image_keys:
# when first frame of episode, initialize frames dict
if im_key not in frames:
frames[im_key] = []
# add current frame to list of frames to render
frames[im_key].append(ep_td[im_key])
frames[im_key].append(item[im_key])
out_dir.mkdir(parents=True, exist_ok=True)
for im_key in dataset.image_keys:
if len(dataset.image_keys) > 1:
im_name = im_key.replace("observation.images.", "")
video_path = out_dir / f"episode_{ep_id}_{im_name}.mp4"
else:
# When episode has no more frame in its list of observation,
# one frame still remains. It is the result of the last action taken.
# It is stored in `"next"`, so we add it to the list of frames to render.
frames[im_key].append(ep_td["next"][im_key])
video_path = out_dir / f"episode_{ep_id}.mp4"
video_paths.append(video_path)
out_dir.mkdir(parents=True, exist_ok=True)
if len(dataset.image_keys) > 1:
camera = im_key[-1]
video_path = out_dir / f"episode_{current_ep_idx}_{camera}.mp4"
else:
video_path = out_dir / f"episode_{current_ep_idx}.mp4"
video_paths.append(str(video_path))
thread = threading.Thread(
target=cat_and_write_video,
args=(str(video_path), frames[im_key], fps),
)
thread.start()
threads.append(thread)
current_ep_idx = ep_idx
# reset list of frames
del frames[im_key]
if num_frames_left == 0:
logging.info("Ran out of frames")
break
if current_ep_idx == NUM_EPISODES_TO_RENDER:
break
thread = threading.Thread(
target=cat_and_write_video,
args=(str(video_path), frames[im_key], dataset.fps),
)
thread.start()
threads.append(thread)
for thread in threads:
thread.join()