Merge remote-tracking branch 'upstream/main' into unify_policy_api

This commit is contained in:
Alexander Soare
2024-04-17 08:08:57 +01:00
53 changed files with 3184 additions and 1124 deletions

View File

@@ -41,7 +41,9 @@ import gymnasium as gym
import imageio
import numpy as np
import torch
from datasets import Dataset
from huggingface_hub import snapshot_download
from PIL import Image as PILImage
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.envs.factory import make_env
@@ -199,38 +201,48 @@ def eval_policy(
ep_dicts = []
num_episodes = dones.shape[0]
total_frames = 0
idx0 = idx1 = 0
data_ids_per_episode = {}
idx_from = 0
for ep_id in range(num_episodes):
num_frames = done_indices[ep_id].item() + 1
total_frames += num_frames
# TODO(rcadene): We need to add a missing last frame which is the observation
# of a done state. it is critical to have this frame for tdmpc to predict a "done observation/state"
ep_dict = {
"action": actions[ep_id, :num_frames],
"episode": torch.tensor([ep_id] * num_frames),
"episode_id": torch.tensor([ep_id] * num_frames),
"frame_id": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / fps,
"next.done": dones[ep_id, :num_frames],
"next.reward": rewards[ep_id, :num_frames].type(torch.float32),
"episode_data_index_from": torch.tensor([idx_from] * num_frames),
"episode_data_index_to": torch.tensor([idx_from + num_frames] * num_frames),
}
for key in observations:
ep_dict[key] = observations[key][ep_id, :num_frames]
ep_dict[key] = observations[key][ep_id][:num_frames]
ep_dicts.append(ep_dict)
total_frames += num_frames
idx1 += num_frames
data_ids_per_episode[ep_id] = torch.arange(idx0, idx1, 1)
idx0 = idx1
idx_from += num_frames
# similar logic is implemented in dataset preprocessing
data_dict = {}
keys = ep_dicts[0].keys()
for key in keys:
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
if "image" not in key:
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
else:
if key not in data_dict:
data_dict[key] = []
for ep_dict in ep_dicts:
for x in ep_dict[key]:
# c h w -> h w c
img = PILImage.fromarray(x.permute(1, 2, 0).numpy())
data_dict[key].append(img)
data_dict["index"] = torch.arange(0, total_frames, 1)
data_dict = Dataset.from_dict(data_dict).with_format("torch")
if max_episodes_rendered > 0:
batch_stacked_frames = np.stack(ep_frames, 1) # (b, t, *)
@@ -280,10 +292,7 @@ def eval_policy(
"eval_s": time.time() - start,
"eval_ep_s": (time.time() - start) / num_episodes,
},
"episodes": {
"data_dict": data_dict,
"data_ids_per_episode": data_ids_per_episode,
},
"episodes": data_dict,
}
if max_episodes_rendered > 0:
info["videos"] = videos

View File

@@ -4,6 +4,8 @@ from pathlib import Path
import hydra
import torch
from datasets import concatenate_datasets
from datasets.utils.logging import disable_progress_bar
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import cycle
@@ -128,29 +130,33 @@ def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float):
return -(n_off * pc_on) / (n_on * (pc_on - 1))
def add_episodes_inplace(episodes, online_dataset, concat_dataset, sampler, pc_online_samples):
data_dict = episodes["data_dict"]
data_ids_per_episode = episodes["data_ids_per_episode"]
def add_episodes_inplace(data_dict, online_dataset, concat_dataset, sampler, pc_online_samples):
first_episode_id = data_dict.select_columns("episode_id")[0]["episode_id"].item()
first_index = data_dict.select_columns("index")[0]["index"].item()
assert first_episode_id == 0, f"We expect the first episode_id to be 0 and not {first_episode_id}"
assert first_index == 0, f"We expect the first first_index to be 0 and not {first_index}"
if len(online_dataset) == 0:
# initialize online dataset
online_dataset.data_dict = data_dict
online_dataset.data_ids_per_episode = data_ids_per_episode
else:
# find episode index and data frame indices according to previous episode in online_dataset
start_episode = max(online_dataset.data_ids_per_episode.keys()) + 1
start_index = online_dataset.data_dict["index"][-1].item() + 1
data_dict["episode"] += start_episode
data_dict["index"] += start_index
start_episode = online_dataset.select_columns("episode_id")[-1]["episode_id"].item() + 1
start_index = online_dataset.select_columns("index")[-1]["index"].item() + 1
def shift_indices(example):
# note: we dont shift "frame_id" since it represents the index of the frame in the episode it belongs to
example["episode_id"] += start_episode
example["index"] += start_index
example["episode_data_index_from"] += start_index
example["episode_data_index_to"] += start_index
return example
disable_progress_bar() # map has a tqdm progress bar
data_dict = data_dict.map(shift_indices)
# extend online dataset
for key in data_dict:
# TODO(rcadene): avoid reallocating memory at every step by preallocating memory or changing our data structure
online_dataset.data_dict[key] = torch.cat([online_dataset.data_dict[key], data_dict[key]])
for ep_id in data_ids_per_episode:
online_dataset.data_ids_per_episode[ep_id + start_episode] = (
data_ids_per_episode[ep_id] + start_index
)
online_dataset.data_dict = concatenate_datasets([online_dataset.data_dict, data_dict])
# update the concatenated dataset length used during sampling
concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
@@ -269,7 +275,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
# create an empty online dataset similar to offline dataset
online_dataset = deepcopy(offline_dataset)
online_dataset.data_dict = {}
online_dataset.data_ids_per_episode = {}
# create dataloader for online training
concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])

View File

@@ -62,12 +62,12 @@ def render_dataset(dataset, out_dir, max_num_episodes):
)
dl_iter = iter(dataloader)
num_episodes = len(dataset.data_ids_per_episode)
for ep_id in range(min(max_num_episodes, num_episodes)):
for ep_id in range(min(max_num_episodes, dataset.num_episodes)):
logging.info(f"Rendering episode {ep_id}")
frames = {}
for _ in dataset.data_ids_per_episode[ep_id]:
end_of_episode = False
while not end_of_episode:
item = next(dl_iter)
for im_key in dataset.image_keys:
@@ -77,6 +77,8 @@ def render_dataset(dataset, out_dir, max_num_episodes):
# add current frame to list of frames to render
frames[im_key].append(item[im_key])
end_of_episode = item["index"].item() == item["episode_data_index_to"].item() - 1
out_dir.mkdir(parents=True, exist_ok=True)
for im_key in dataset.image_keys:
if len(dataset.image_keys) > 1: