Loads episode_data_index and stats during dataset __init__ (#85)

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
Remi
2024-04-23 14:13:25 +02:00
committed by GitHub
parent e2168163cd
commit 1030ea0070
89 changed files with 1008 additions and 432 deletions

View File

@@ -50,7 +50,12 @@ available_datasets = {
"aloha_sim_transfer_cube_scripted",
],
"pusht": ["pusht"],
"xarm": ["xarm_lift_medium"],
"xarm": [
"xarm_lift_medium",
"xarm_lift_medium_replay",
"xarm_push_medium",
"xarm_push_medium_replay",
],
}
available_policies = [

View File

@@ -1,9 +1,13 @@
from pathlib import Path
import torch
from datasets import load_dataset, load_from_disk
from lerobot.common.datasets.utils import load_previous_and_future_frames
from lerobot.common.datasets.utils import (
load_episode_data_index,
load_hf_dataset,
load_previous_and_future_frames,
load_stats,
)
class AlohaDataset(torch.utils.data.Dataset):
@@ -27,7 +31,7 @@ class AlohaDataset(torch.utils.data.Dataset):
def __init__(
self,
dataset_id: str,
version: str | None = "v1.0",
version: str | None = "v1.1",
root: Path | None = None,
split: str = "train",
transform: callable = None,
@@ -40,13 +44,10 @@ class AlohaDataset(torch.utils.data.Dataset):
self.split = split
self.transform = transform
self.delta_timestamps = delta_timestamps
if self.root is not None:
self.hf_dataset = load_from_disk(Path(self.root) / self.dataset_id / self.split)
else:
self.hf_dataset = load_dataset(
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
)
self.hf_dataset = self.hf_dataset.with_format("torch")
# load data from hub or locally when root is provided
self.hf_dataset = load_hf_dataset(dataset_id, version, root, split)
self.episode_data_index = load_episode_data_index(dataset_id, version, root)
self.stats = load_stats(dataset_id, version, root)
@property
def num_samples(self) -> int:
@@ -54,7 +55,7 @@ class AlohaDataset(torch.utils.data.Dataset):
@property
def num_episodes(self) -> int:
return len(self.hf_dataset.unique("episode_id"))
return len(self.hf_dataset.unique("episode_index"))
def __len__(self):
return self.num_samples
@@ -66,19 +67,11 @@ class AlohaDataset(torch.utils.data.Dataset):
item = load_previous_and_future_frames(
item,
self.hf_dataset,
self.episode_data_index,
self.delta_timestamps,
tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
)
# convert images from channel last (PIL) to channel first (pytorch)
for key in self.image_keys:
if item[key].ndim == 3:
item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w
elif item[key].ndim == 4:
item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w
else:
raise ValueError(item[key].ndim)
if self.transform is not None:
item = self.transform(item)

View File

@@ -1,12 +1,10 @@
import logging
import os
from pathlib import Path
import torch
from torchvision.transforms import v2
from lerobot.common.datasets.utils import compute_stats
from lerobot.common.transforms import NormalizeTransform, Prod
from lerobot.common.transforms import NormalizeTransform
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
@@ -52,32 +50,18 @@ def make_dataset(
stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
elif stats_path is None:
# load stats if the file exists already or compute stats and save it
if DATA_DIR is None:
# TODO(rcadene): clean stats
precomputed_stats_path = Path("data") / cfg.dataset_id / "stats.pth"
else:
precomputed_stats_path = DATA_DIR / cfg.dataset_id / "stats.pth"
if precomputed_stats_path.exists():
stats = torch.load(precomputed_stats_path)
else:
logging.info(f"compute_stats and save to {precomputed_stats_path}")
# Create a dataset for stats computation.
stats_dataset = clsfunc(
dataset_id=cfg.dataset_id,
split="train",
root=DATA_DIR,
transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
)
stats = compute_stats(stats_dataset)
precomputed_stats_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(stats, precomputed_stats_path)
# load a first dataset to access precomputed stats
stats_dataset = clsfunc(
dataset_id=cfg.dataset_id,
split="train",
root=DATA_DIR,
)
stats = stats_dataset.stats
else:
stats = torch.load(stats_path)
transforms = v2.Compose(
[
Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
NormalizeTransform(
stats,
in_keys=[

View File

@@ -1,9 +1,13 @@
from pathlib import Path
import torch
from datasets import load_dataset, load_from_disk
from lerobot.common.datasets.utils import load_previous_and_future_frames
from lerobot.common.datasets.utils import (
load_episode_data_index,
load_hf_dataset,
load_previous_and_future_frames,
load_stats,
)
class PushtDataset(torch.utils.data.Dataset):
@@ -25,7 +29,7 @@ class PushtDataset(torch.utils.data.Dataset):
def __init__(
self,
dataset_id: str = "pusht",
version: str | None = "v1.0",
version: str | None = "v1.1",
root: Path | None = None,
split: str = "train",
transform: callable = None,
@@ -38,13 +42,10 @@ class PushtDataset(torch.utils.data.Dataset):
self.split = split
self.transform = transform
self.delta_timestamps = delta_timestamps
if self.root is not None:
self.hf_dataset = load_from_disk(Path(self.root) / self.dataset_id / self.split)
else:
self.hf_dataset = load_dataset(
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
)
self.hf_dataset = self.hf_dataset.with_format("torch")
# load data from hub or locally when root is provided
self.hf_dataset = load_hf_dataset(dataset_id, version, root, split)
self.episode_data_index = load_episode_data_index(dataset_id, version, root)
self.stats = load_stats(dataset_id, version, root)
@property
def num_samples(self) -> int:
@@ -52,7 +53,7 @@ class PushtDataset(torch.utils.data.Dataset):
@property
def num_episodes(self) -> int:
return len(self.hf_dataset.unique("episode_id"))
return len(self.episode_data_index["from"])
def __len__(self):
return self.num_samples
@@ -64,19 +65,11 @@ class PushtDataset(torch.utils.data.Dataset):
item = load_previous_and_future_frames(
item,
self.hf_dataset,
self.episode_data_index,
self.delta_timestamps,
tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
)
# convert images from channel last (PIL) to channel first (pytorch)
for key in self.image_keys:
if item[key].ndim == 3:
item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w
elif item[key].ndim == 4:
item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w
else:
raise ValueError(item[key].ndim)
if self.transform is not None:
item = self.transform(item)

View File

@@ -1,15 +1,121 @@
from copy import deepcopy
from math import ceil
from pathlib import Path
import datasets
import einops
import torch
import tqdm
from datasets import Image, load_dataset, load_from_disk
from huggingface_hub import hf_hub_download
from PIL import Image as PILImage
from safetensors.torch import load_file
from torchvision import transforms
def flatten_dict(d, parent_key="", sep="/"):
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
For example:
```
>>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}`
>>> print(flatten_dict(dct))
{"a/b": 1, "a/c/d": 2, "e": 3}
"""
items = []
for k, v in d.items():
new_key = f"{parent_key}{sep}{k}" if parent_key else k
if isinstance(v, dict):
items.extend(flatten_dict(v, new_key, sep=sep).items())
else:
items.append((new_key, v))
return dict(items)
def unflatten_dict(d, sep="/"):
outdict = {}
for key, value in d.items():
parts = key.split(sep)
d = outdict
for part in parts[:-1]:
if part not in d:
d[part] = {}
d = d[part]
d[parts[-1]] = value
return outdict
def hf_transform_to_torch(items_dict):
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
to torch tensors. Importantly, images are converted from PIL, which corresponds to
a channel last representation (h w c) of uint8 type, to a torch image representation
with channel first (c h w) of float32 type in range [0,1].
"""
for key in items_dict:
first_item = items_dict[key][0]
if isinstance(first_item, PILImage.Image):
to_tensor = transforms.ToTensor()
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
else:
items_dict[key] = [torch.tensor(x) for x in items_dict[key]]
return items_dict
def load_hf_dataset(dataset_id, version, root, split) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
if root is not None:
hf_dataset = load_from_disk(str(Path(root) / dataset_id / split))
else:
# TODO(rcadene): remove dataset_id everywhere and use repo_id instead
repo_id = f"lerobot/{dataset_id}"
hf_dataset = load_dataset(repo_id, revision=version, split=split)
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
def load_episode_data_index(dataset_id, version, root) -> dict[str, torch.Tensor]:
"""episode_data_index contains the range of indices for each episode
Example:
```python
from_id = episode_data_index["from"][episode_id].item()
to_id = episode_data_index["to"][episode_id].item()
episode_frames = [dataset[i] for i in range(from_id, to_id)]
```
"""
if root is not None:
path = Path(root) / dataset_id / "meta_data" / "episode_data_index.safetensors"
else:
repo_id = f"lerobot/{dataset_id}"
path = hf_hub_download(
repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=version
)
return load_file(path)
def load_stats(dataset_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
"""stats contains the statistics per modality computed over the full dataset, such as max, min, mean, std
Example:
```python
normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"]
```
"""
if root is not None:
path = Path(root) / dataset_id / "meta_data" / "stats.safetensors"
else:
repo_id = f"lerobot/{dataset_id}"
path = hf_hub_download(repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=version)
stats = load_file(path)
return unflatten_dict(stats)
def load_previous_and_future_frames(
item: dict[str, torch.Tensor],
hf_dataset: datasets.Dataset,
episode_data_index: dict[str, torch.Tensor],
delta_timestamps: dict[str, list[float]],
tol: float,
) -> dict[torch.Tensor]:
@@ -31,6 +137,8 @@ def load_previous_and_future_frames(
corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
- hf_dataset (datasets.Dataset): A dictionary containing the full dataset. Each key corresponds to a different
modality (e.g., "timestamp", "observation.image", "action").
- episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices.
They indicate the start index and end index of each episode in the dataset.
- delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be
retrieved. These deltas are added to the item timestamp to form the query timestamps.
- tol (float, optional): The tolerance level used to determine if a data point is close enough to the query
@@ -46,12 +154,14 @@ def load_previous_and_future_frames(
issues with timestamps during data collection.
"""
# get indices of the frames associated to the episode, and their timestamps
ep_data_id_from = item["episode_data_index_from"].item()
ep_data_id_to = item["episode_data_index_to"].item()
ep_id = item["episode_index"].item()
ep_data_id_from = episode_data_index["from"][ep_id].item()
ep_data_id_to = episode_data_index["to"][ep_id].item()
ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to, 1)
# load timestamps
ep_timestamps = hf_dataset.select_columns("timestamp")[ep_data_id_from:ep_data_id_to]["timestamp"]
ep_timestamps = torch.stack(ep_timestamps)
# we make the assumption that the timestamps are sorted
ep_first_ts = ep_timestamps[0]
@@ -82,39 +192,57 @@ def load_previous_and_future_frames(
# load frames modality
item[key] = hf_dataset.select_columns(key)[data_ids][key]
item[key] = torch.stack(item[key])
item[f"{key}_is_pad"] = is_pad
return item
def get_stats_einops_patterns(dataset):
"""These einops patterns will be used to aggregate batches and compute statistics."""
stats_patterns = {
"action": "b c -> c",
"observation.state": "b c -> c",
}
for key in dataset.image_keys:
stats_patterns[key] = "b c h w -> c 1 1"
def get_stats_einops_patterns(hf_dataset):
"""These einops patterns will be used to aggregate batches and compute statistics.
Note: We assume the images of `hf_dataset` are in channel first format
"""
dataloader = torch.utils.data.DataLoader(
hf_dataset,
num_workers=0,
batch_size=2,
shuffle=False,
)
batch = next(iter(dataloader))
stats_patterns = {}
for key, feats_type in hf_dataset.features.items():
# sanity check that tensors are not float64
assert batch[key].dtype != torch.float64
if isinstance(feats_type, Image):
# sanity check that images are channel first
_, c, h, w = batch[key].shape
assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
# sanity check that images are float32 in range [0,1]
assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}"
assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}"
assert batch[key].min() >= 0, f"expect pixels greater than 1, but instead {batch[key].min()=}"
stats_patterns[key] = "b c h w -> c 1 1"
elif batch[key].ndim == 2:
stats_patterns[key] = "b c -> c "
elif batch[key].ndim == 1:
stats_patterns[key] = "b -> 1"
else:
raise ValueError(f"{key}, {feats_type}, {batch[key].shape}")
return stats_patterns
def compute_stats(dataset, batch_size=32, max_num_samples=None):
def compute_stats(hf_dataset, batch_size=32, max_num_samples=None):
if max_num_samples is None:
max_num_samples = len(dataset)
else:
raise NotImplementedError("We need to set shuffle=True, but this violate an assert for now.")
max_num_samples = len(hf_dataset)
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=4,
batch_size=batch_size,
shuffle=False,
# pin_memory=cfg.device != "cpu",
drop_last=False,
)
# get einops patterns to aggregate batches and compute statistics
stats_patterns = get_stats_einops_patterns(dataset)
stats_patterns = get_stats_einops_patterns(hf_dataset)
# mean and std will be computed incrementally while max and min will track the running value.
mean, std, max, min = {}, {}, {}, {}
@@ -124,10 +252,24 @@ def compute_stats(dataset, batch_size=32, max_num_samples=None):
max[key] = torch.tensor(-float("inf")).float()
min[key] = torch.tensor(float("inf")).float()
def create_seeded_dataloader(hf_dataset, batch_size, seed):
generator = torch.Generator()
generator.manual_seed(seed)
dataloader = torch.utils.data.DataLoader(
hf_dataset,
num_workers=4,
batch_size=batch_size,
shuffle=True,
drop_last=False,
generator=generator,
)
return dataloader
# Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
# surprises when rerunning the sampler.
first_batch = None
running_item_count = 0 # for online mean computation
dataloader = create_seeded_dataloader(hf_dataset, batch_size, seed=1337)
for i, batch in enumerate(
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
):
@@ -153,6 +295,7 @@ def compute_stats(dataset, batch_size=32, max_num_samples=None):
first_batch_ = None
running_item_count = 0 # for online std computation
dataloader = create_seeded_dataloader(hf_dataset, batch_size, seed=1337)
for i, batch in enumerate(
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
):

View File

@@ -1,25 +1,37 @@
from pathlib import Path
import torch
from datasets import load_dataset, load_from_disk
from lerobot.common.datasets.utils import load_previous_and_future_frames
from lerobot.common.datasets.utils import (
load_episode_data_index,
load_hf_dataset,
load_previous_and_future_frames,
load_stats,
)
class XarmDataset(torch.utils.data.Dataset):
"""
https://huggingface.co/datasets/lerobot/xarm_lift_medium
https://huggingface.co/datasets/lerobot/xarm_lift_medium_replay
https://huggingface.co/datasets/lerobot/xarm_push_medium
https://huggingface.co/datasets/lerobot/xarm_push_medium_replay
"""
# Copied from lerobot/__init__.py
available_datasets = ["xarm_lift_medium"]
available_datasets = [
"xarm_lift_medium",
"xarm_lift_medium_replay",
"xarm_push_medium",
"xarm_push_medium_replay",
]
fps = 15
image_keys = ["observation.image"]
def __init__(
self,
dataset_id: str = "xarm_lift_medium",
version: str | None = "v1.0",
dataset_id: str,
version: str | None = "v1.1",
root: Path | None = None,
split: str = "train",
transform: callable = None,
@@ -32,13 +44,10 @@ class XarmDataset(torch.utils.data.Dataset):
self.split = split
self.transform = transform
self.delta_timestamps = delta_timestamps
if self.root is not None:
self.hf_dataset = load_from_disk(Path(self.root) / self.dataset_id / self.split)
else:
self.hf_dataset = load_dataset(
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
)
self.hf_dataset = self.hf_dataset.with_format("torch")
# load data from hub or locally when root is provided
self.hf_dataset = load_hf_dataset(dataset_id, version, root, split)
self.episode_data_index = load_episode_data_index(dataset_id, version, root)
self.stats = load_stats(dataset_id, version, root)
@property
def num_samples(self) -> int:
@@ -46,7 +55,7 @@ class XarmDataset(torch.utils.data.Dataset):
@property
def num_episodes(self) -> int:
return len(self.hf_dataset.unique("episode_id"))
return len(self.hf_dataset.unique("episode_index"))
def __len__(self):
return self.num_samples
@@ -58,19 +67,11 @@ class XarmDataset(torch.utils.data.Dataset):
item = load_previous_and_future_frames(
item,
self.hf_dataset,
self.episode_data_index,
self.delta_timestamps,
tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
)
# convert images from channel last (PIL) to channel first (pytorch)
for key in self.image_keys:
if item[key].ndim == 3:
item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w
elif item[key].ndim == 4:
item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w
else:
raise ValueError(item[key].ndim)
if self.transform is not None:
item = self.transform(item)

View File

@@ -39,4 +39,5 @@ def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv:
for _ in range(num_parallel_envs)
]
)
return env

View File

@@ -15,8 +15,19 @@ def preprocess_observation(observation, transform=None):
for imgkey, img in imgs.items():
img = torch.from_numpy(img)
# convert to (b c h w) torch format
# sanity check that images are channel last
_, h, w, c = img.shape
assert c < h and c < w, f"expect channel first images, but instead {img.shape}"
# sanity check that images are uint8
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
# convert to channel first of type float32 in range [0,1]
img = einops.rearrange(img, "b h w c -> b c h w")
img = img.type(torch.float32)
img /= 255
obs[imgkey] = img
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing requirement for "agent_pos"

View File

@@ -1,4 +1,3 @@
import torch
from torchvision.transforms.v2 import Compose, Transform
@@ -12,40 +11,6 @@ def apply_inverse_transform(item, transform):
return item
class Prod(Transform):
invertible = True
def __init__(self, in_keys: list[str], prod: float):
super().__init__()
self.in_keys = in_keys
self.prod = prod
self.original_dtypes = {}
def forward(self, item):
for key in self.in_keys:
if key not in item:
continue
self.original_dtypes[key] = item[key].dtype
item[key] = item[key].type(torch.float32) * self.prod
return item
def inverse_transform(self, item):
for key in self.in_keys:
if key not in item:
continue
item[key] = (item[key] / self.prod).type(self.original_dtypes[key])
return item
# def transform_observation_spec(self, obs_spec):
# for key in self.in_keys:
# if obs_spec.get(key, None) is None:
# continue
# obs_spec[key].space.high = obs_spec[key].space.high.type(torch.float32) * self.prod
# obs_spec[key].space.low = obs_spec[key].space.low.type(torch.float32) * self.prod
# obs_spec[key].dtype = torch.float32
# return obs_spec
class NormalizeTransform(Transform):
invertible = True

View File

@@ -47,6 +47,7 @@ from PIL import Image as PILImage
from tqdm import trange
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import hf_transform_to_torch
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
from lerobot.common.logger import log_output_dir
@@ -208,11 +209,12 @@ def eval_policy(
max_rewards.extend(batch_max_reward.tolist())
all_successes.extend(batch_success.tolist())
# similar logic is implemented in dataset preprocessing
# similar logic is implemented when datasets are pushed to hub (see: `push_to_hub`)
ep_dicts = []
episode_data_index = {"from": [], "to": []}
num_episodes = dones.shape[0]
total_frames = 0
idx_from = 0
id_from = 0
for ep_id in range(num_episodes):
num_frames = done_indices[ep_id].item() + 1
total_frames += num_frames
@@ -222,19 +224,20 @@ def eval_policy(
if return_episode_data:
ep_dict = {
"action": actions[ep_id, :num_frames],
"episode_id": torch.tensor([ep_id] * num_frames),
"frame_id": torch.arange(0, num_frames, 1),
"episode_index": torch.tensor([ep_id] * num_frames),
"frame_index": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / fps,
"next.done": dones[ep_id, :num_frames],
"next.reward": rewards[ep_id, :num_frames].type(torch.float32),
"episode_data_index_from": torch.tensor([idx_from] * num_frames),
"episode_data_index_to": torch.tensor([idx_from + num_frames] * num_frames),
}
for key in observations:
ep_dict[key] = observations[key][ep_id][:num_frames]
ep_dicts.append(ep_dict)
idx_from += num_frames
episode_data_index["from"].append(id_from)
episode_data_index["to"].append(id_from + num_frames)
id_from += num_frames
# similar logic is implemented in dataset preprocessing
if return_episode_data:
@@ -247,14 +250,29 @@ def eval_policy(
if key not in data_dict:
data_dict[key] = []
for ep_dict in ep_dicts:
for x in ep_dict[key]:
# c h w -> h w c
img = PILImage.fromarray(x.permute(1, 2, 0).numpy())
for img in ep_dict[key]:
# sanity check that images are channel first
c, h, w = img.shape
assert c < h and c < w, f"expect channel first images, but instead {img.shape}"
# sanity check that images are float32 in range [0,1]
assert img.dtype == torch.float32, f"expect torch.float32, but instead {img.dtype=}"
assert img.max() <= 1, f"expect pixels lower than 1, but instead {img.max()=}"
assert img.min() >= 0, f"expect pixels greater than 1, but instead {img.min()=}"
# from float32 in range [0,1] to uint8 in range [0,255]
img *= 255
img = img.type(torch.uint8)
# convert to channel last and numpy as expected by PIL
img = PILImage.fromarray(img.permute(1, 2, 0).numpy())
data_dict[key].append(img)
data_dict["index"] = torch.arange(0, total_frames, 1)
hf_dataset = Dataset.from_dict(data_dict).with_format("torch")
hf_dataset = Dataset.from_dict(data_dict)
hf_dataset.set_transform(hf_transform_to_torch)
if max_episodes_rendered > 0:
batch_stacked_frames = np.stack(ep_frames, 1) # (b, t, *)
@@ -307,7 +325,10 @@ def eval_policy(
},
}
if return_episode_data:
info["episodes"] = hf_dataset
info["episodes"] = {
"hf_dataset": hf_dataset,
"episode_data_index": episode_data_index,
}
if max_episodes_rendered > 0:
info["videos"] = videos
return info

View File

@@ -136,6 +136,7 @@ def add_episodes_inplace(
concat_dataset: torch.utils.data.ConcatDataset,
sampler: torch.utils.data.WeightedRandomSampler,
hf_dataset: datasets.Dataset,
episode_data_index: dict[str, torch.Tensor],
pc_online_samples: float,
):
"""
@@ -151,13 +152,15 @@ def add_episodes_inplace(
- sampler (torch.utils.data.WeightedRandomSampler): A sampler that will be updated to
reflect changes in the dataset sizes and specified sampling weights.
- hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added.
- episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices.
They indicate the start index and end index of each episode in the dataset.
- pc_online_samples (float): The target percentage of samples that should come from
the online dataset during sampling operations.
Raises:
- AssertionError: If the first episode_id or index in hf_dataset is not 0
"""
first_episode_id = hf_dataset.select_columns("episode_id")[0]["episode_id"].item()
first_episode_id = hf_dataset.select_columns("episode_index")[0]["episode_index"].item()
first_index = hf_dataset.select_columns("index")[0]["index"].item()
assert first_episode_id == 0, f"We expect the first episode_id to be 0 and not {first_episode_id}"
assert first_index == 0, f"We expect the first first_index to be 0 and not {first_index}"
@@ -167,21 +170,22 @@ def add_episodes_inplace(
online_dataset.hf_dataset = hf_dataset
else:
# find episode index and data frame indices according to previous episode in online_dataset
start_episode = online_dataset.select_columns("episode_id")[-1]["episode_id"].item() + 1
start_episode = online_dataset.select_columns("episode_index")[-1]["episode_index"].item() + 1
start_index = online_dataset.select_columns("index")[-1]["index"].item() + 1
def shift_indices(example):
# note: we dont shift "frame_id" since it represents the index of the frame in the episode it belongs to
example["episode_id"] += start_episode
# note: we dont shift "frame_index" since it represents the index of the frame in the episode it belongs to
example["episode_index"] += start_episode
example["index"] += start_index
example["episode_data_index_from"] += start_index
example["episode_data_index_to"] += start_index
return example
disable_progress_bars() # map has a tqdm progress bar
hf_dataset = hf_dataset.map(shift_indices)
enable_progress_bars()
episode_data_index["from"] += start_index
episode_data_index["to"] += start_index
# extend online dataset
online_dataset.hf_dataset = concatenate_datasets([online_dataset.hf_dataset, hf_dataset])
@@ -334,9 +338,13 @@ def train(cfg: dict, out_dir=None, job_name=None):
seed=cfg.seed,
)
online_pc_sampling = cfg.get("demo_schedule", 0.5)
add_episodes_inplace(
online_dataset, concat_dataset, sampler, eval_info["episodes"], online_pc_sampling
online_dataset,
concat_dataset,
sampler,
hf_dataset=eval_info["episodes"]["hf_dataset"],
episode_data_index=eval_info["episodes"]["episode_data_index"],
pc_online_samples=cfg.get("demo_schedule", 0.5),
)
for _ in range(cfg.policy.utd):

View File

@@ -22,11 +22,24 @@ def visualize_dataset_cli(cfg: dict):
def cat_and_write_video(video_path, frames, fps):
# Expects images in [0, 255].
frames = torch.cat(frames)
assert frames.dtype == torch.uint8
frames = einops.rearrange(frames, "b c h w -> b h w c").numpy()
imageio.mimsave(video_path, frames, fps=fps)
# Expects images in [0, 1].
frame = frames[0]
if frame.ndim == 4:
raise NotImplementedError("We currently dont support multiple timestamps.")
c, h, w = frame.shape
assert c < h and c < w, f"expect channel first images, but instead {frame.shape}"
# sanity check that images are float32 in range [0,1]
assert frame.dtype == torch.float32, f"expect torch.float32, but instead {frame.dtype=}"
assert frame.max() <= 1, f"expect pixels lower than 1, but instead {frame.max()=}"
assert frame.min() >= 0, f"expect pixels greater than 1, but instead {frame.min()=}"
# convert to channel last uint8 [0, 255]
frames = einops.rearrange(frames, "b c h w -> b h w c")
frames = (frames * 255).type(torch.uint8)
imageio.mimsave(video_path, frames.numpy(), fps=fps)
def visualize_dataset(cfg: dict, out_dir=None):
@@ -44,9 +57,10 @@ def visualize_dataset(cfg: dict, out_dir=None):
)
logging.info("Start rendering episodes from offline buffer")
video_paths = render_dataset(dataset, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER, cfg.fps)
video_paths = render_dataset(dataset, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER)
for video_path in video_paths:
logging.info(video_path)
return video_paths
def render_dataset(dataset, out_dir, max_num_episodes):
@@ -77,7 +91,7 @@ def render_dataset(dataset, out_dir, max_num_episodes):
# add current frame to list of frames to render
frames[im_key].append(item[im_key])
end_of_episode = item["index"].item() == item["episode_data_index_to"].item() - 1
end_of_episode = item["index"].item() == dataset.episode_data_index["to"][ep_id] - 1
out_dir.mkdir(parents=True, exist_ok=True)
for im_key in dataset.image_keys: