Loads episode_data_index and stats during dataset __init__ (#85)
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
@@ -50,7 +50,12 @@ available_datasets = {
|
||||
"aloha_sim_transfer_cube_scripted",
|
||||
],
|
||||
"pusht": ["pusht"],
|
||||
"xarm": ["xarm_lift_medium"],
|
||||
"xarm": [
|
||||
"xarm_lift_medium",
|
||||
"xarm_lift_medium_replay",
|
||||
"xarm_push_medium",
|
||||
"xarm_push_medium_replay",
|
||||
],
|
||||
}
|
||||
|
||||
available_policies = [
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset, load_from_disk
|
||||
|
||||
from lerobot.common.datasets.utils import load_previous_and_future_frames
|
||||
from lerobot.common.datasets.utils import (
|
||||
load_episode_data_index,
|
||||
load_hf_dataset,
|
||||
load_previous_and_future_frames,
|
||||
load_stats,
|
||||
)
|
||||
|
||||
|
||||
class AlohaDataset(torch.utils.data.Dataset):
|
||||
@@ -27,7 +31,7 @@ class AlohaDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id: str,
|
||||
version: str | None = "v1.0",
|
||||
version: str | None = "v1.1",
|
||||
root: Path | None = None,
|
||||
split: str = "train",
|
||||
transform: callable = None,
|
||||
@@ -40,13 +44,10 @@ class AlohaDataset(torch.utils.data.Dataset):
|
||||
self.split = split
|
||||
self.transform = transform
|
||||
self.delta_timestamps = delta_timestamps
|
||||
if self.root is not None:
|
||||
self.hf_dataset = load_from_disk(Path(self.root) / self.dataset_id / self.split)
|
||||
else:
|
||||
self.hf_dataset = load_dataset(
|
||||
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
|
||||
)
|
||||
self.hf_dataset = self.hf_dataset.with_format("torch")
|
||||
# load data from hub or locally when root is provided
|
||||
self.hf_dataset = load_hf_dataset(dataset_id, version, root, split)
|
||||
self.episode_data_index = load_episode_data_index(dataset_id, version, root)
|
||||
self.stats = load_stats(dataset_id, version, root)
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
@@ -54,7 +55,7 @@ class AlohaDataset(torch.utils.data.Dataset):
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
return len(self.hf_dataset.unique("episode_id"))
|
||||
return len(self.hf_dataset.unique("episode_index"))
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
@@ -66,19 +67,11 @@ class AlohaDataset(torch.utils.data.Dataset):
|
||||
item = load_previous_and_future_frames(
|
||||
item,
|
||||
self.hf_dataset,
|
||||
self.episode_data_index,
|
||||
self.delta_timestamps,
|
||||
tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
|
||||
)
|
||||
|
||||
# convert images from channel last (PIL) to channel first (pytorch)
|
||||
for key in self.image_keys:
|
||||
if item[key].ndim == 3:
|
||||
item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w
|
||||
elif item[key].ndim == 4:
|
||||
item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w
|
||||
else:
|
||||
raise ValueError(item[key].ndim)
|
||||
|
||||
if self.transform is not None:
|
||||
item = self.transform(item)
|
||||
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torchvision.transforms import v2
|
||||
|
||||
from lerobot.common.datasets.utils import compute_stats
|
||||
from lerobot.common.transforms import NormalizeTransform, Prod
|
||||
from lerobot.common.transforms import NormalizeTransform
|
||||
|
||||
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
||||
|
||||
@@ -52,32 +50,18 @@ def make_dataset(
|
||||
stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
||||
stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
||||
elif stats_path is None:
|
||||
# load stats if the file exists already or compute stats and save it
|
||||
if DATA_DIR is None:
|
||||
# TODO(rcadene): clean stats
|
||||
precomputed_stats_path = Path("data") / cfg.dataset_id / "stats.pth"
|
||||
else:
|
||||
precomputed_stats_path = DATA_DIR / cfg.dataset_id / "stats.pth"
|
||||
if precomputed_stats_path.exists():
|
||||
stats = torch.load(precomputed_stats_path)
|
||||
else:
|
||||
logging.info(f"compute_stats and save to {precomputed_stats_path}")
|
||||
# Create a dataset for stats computation.
|
||||
stats_dataset = clsfunc(
|
||||
dataset_id=cfg.dataset_id,
|
||||
split="train",
|
||||
root=DATA_DIR,
|
||||
transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
|
||||
)
|
||||
stats = compute_stats(stats_dataset)
|
||||
precomputed_stats_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(stats, precomputed_stats_path)
|
||||
# load a first dataset to access precomputed stats
|
||||
stats_dataset = clsfunc(
|
||||
dataset_id=cfg.dataset_id,
|
||||
split="train",
|
||||
root=DATA_DIR,
|
||||
)
|
||||
stats = stats_dataset.stats
|
||||
else:
|
||||
stats = torch.load(stats_path)
|
||||
|
||||
transforms = v2.Compose(
|
||||
[
|
||||
Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
|
||||
NormalizeTransform(
|
||||
stats,
|
||||
in_keys=[
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset, load_from_disk
|
||||
|
||||
from lerobot.common.datasets.utils import load_previous_and_future_frames
|
||||
from lerobot.common.datasets.utils import (
|
||||
load_episode_data_index,
|
||||
load_hf_dataset,
|
||||
load_previous_and_future_frames,
|
||||
load_stats,
|
||||
)
|
||||
|
||||
|
||||
class PushtDataset(torch.utils.data.Dataset):
|
||||
@@ -25,7 +29,7 @@ class PushtDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id: str = "pusht",
|
||||
version: str | None = "v1.0",
|
||||
version: str | None = "v1.1",
|
||||
root: Path | None = None,
|
||||
split: str = "train",
|
||||
transform: callable = None,
|
||||
@@ -38,13 +42,10 @@ class PushtDataset(torch.utils.data.Dataset):
|
||||
self.split = split
|
||||
self.transform = transform
|
||||
self.delta_timestamps = delta_timestamps
|
||||
if self.root is not None:
|
||||
self.hf_dataset = load_from_disk(Path(self.root) / self.dataset_id / self.split)
|
||||
else:
|
||||
self.hf_dataset = load_dataset(
|
||||
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
|
||||
)
|
||||
self.hf_dataset = self.hf_dataset.with_format("torch")
|
||||
# load data from hub or locally when root is provided
|
||||
self.hf_dataset = load_hf_dataset(dataset_id, version, root, split)
|
||||
self.episode_data_index = load_episode_data_index(dataset_id, version, root)
|
||||
self.stats = load_stats(dataset_id, version, root)
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
@@ -52,7 +53,7 @@ class PushtDataset(torch.utils.data.Dataset):
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
return len(self.hf_dataset.unique("episode_id"))
|
||||
return len(self.episode_data_index["from"])
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
@@ -64,19 +65,11 @@ class PushtDataset(torch.utils.data.Dataset):
|
||||
item = load_previous_and_future_frames(
|
||||
item,
|
||||
self.hf_dataset,
|
||||
self.episode_data_index,
|
||||
self.delta_timestamps,
|
||||
tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
|
||||
)
|
||||
|
||||
# convert images from channel last (PIL) to channel first (pytorch)
|
||||
for key in self.image_keys:
|
||||
if item[key].ndim == 3:
|
||||
item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w
|
||||
elif item[key].ndim == 4:
|
||||
item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w
|
||||
else:
|
||||
raise ValueError(item[key].ndim)
|
||||
|
||||
if self.transform is not None:
|
||||
item = self.transform(item)
|
||||
|
||||
|
||||
@@ -1,15 +1,121 @@
|
||||
from copy import deepcopy
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import einops
|
||||
import torch
|
||||
import tqdm
|
||||
from datasets import Image, load_dataset, load_from_disk
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image as PILImage
|
||||
from safetensors.torch import load_file
|
||||
from torchvision import transforms
|
||||
|
||||
|
||||
def flatten_dict(d, parent_key="", sep="/"):
    """Flatten a nested dictionary by joining nested keys into one key with a separator.

    Values that are themselves dicts are descended into recursively; every other
    value is stored unchanged under its joined key. A nested dict that is empty
    contributes no entry to the result.

    Example:
        >>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}
        >>> print(flatten_dict(dct))
        {'a/b': 1, 'a/c/d': 2, 'e': 3}

    Args:
        d: The (possibly nested) dictionary to flatten.
        parent_key: Prefix put in front of every key of `d`. An empty (falsy)
            prefix means the keys are used as-is.
        sep: Separator inserted between the prefix and each key.

    Returns:
        A new single-level dictionary.
    """
    flat = {}
    for k, v in d.items():
        # Only prepend the prefix when there is one, so top-level keys stay untouched.
        new_key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, dict):
            flat.update(flatten_dict(v, new_key, sep=sep))
        else:
            flat[new_key] = v
    return flat
|
||||
|
||||
|
||||
def unflatten_dict(d, sep="/"):
    """Rebuild a nested dictionary from a flat one whose keys are joined by `sep`.

    Each key is split on `sep`; all parts but the last name the nested dicts to
    walk through (created on demand), and the last part is the key under which
    the value is stored.

    Example:
        >>> unflatten_dict({"a/b": 1, "a/c/d": 2, "e": 3})
        {'a': {'b': 1, 'c': {'d': 2}}, 'e': 3}

    Args:
        d: Flat dictionary with string keys.
        sep: Separator that delimits nesting levels inside each key.

    Returns:
        A new nested dictionary. `d` itself is not modified.
    """
    outdict = {}
    for key, value in d.items():
        *parents, leaf = key.split(sep)
        # Walk (and lazily create) the chain of intermediate dicts. A dedicated
        # cursor name is used so the `d` argument is never rebound mid-iteration.
        node = outdict
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return outdict
|
||||
|
||||
|
||||
def hf_transform_to_torch(items_dict):
    """Convert a batch of items from a Hugging Face dataset (pyarrow) to torch tensors.

    Importantly, images are converted from PIL, which corresponds to a channel last
    representation (h w c) of uint8 type, to a torch image representation with
    channel first (c h w) of float32 type in range [0,1].

    Args:
        items_dict: Mapping from column name to the list of values of that column.
            The type of a column is decided from its first element only.

    Returns:
        The same `items_dict` object, with every non-empty column replaced by a
        list of tensors. Empty columns are left untouched.
    """
    for key, values in items_dict.items():
        if len(values) == 0:
            # Nothing to convert, and there is no first element to inspect.
            continue
        if isinstance(values[0], PILImage.Image):
            to_tensor = transforms.ToTensor()
            items_dict[key] = [to_tensor(img) for img in values]
        else:
            items_dict[key] = [torch.tensor(x) for x in values]
    return items_dict
|
||||
|
||||
|
||||
def load_hf_dataset(dataset_id, version, root, split) -> datasets.Dataset:
    """Load the hf_dataset, which holds all the observations, states, actions, rewards, etc.

    When `root` is given, the split is read from disk at `root/dataset_id/split`;
    otherwise it is fetched from the `lerobot/{dataset_id}` repository at revision
    `version`. In both cases `hf_transform_to_torch` is installed as the dataset
    transform before returning.
    """
    if root is None:
        # TODO(rcadene): remove dataset_id everywhere and use repo_id instead
        hf_dataset = load_dataset(f"lerobot/{dataset_id}", revision=version, split=split)
    else:
        local_dir = Path(root) / dataset_id / split
        hf_dataset = load_from_disk(str(local_dir))

    hf_dataset.set_transform(hf_transform_to_torch)
    return hf_dataset
|
||||
|
||||
|
||||
def load_episode_data_index(dataset_id, version, root) -> dict[str, torch.Tensor]:
    """Load episode_data_index, which contains the range of indices for each episode.

    The safetensors file is read from `root/dataset_id/meta_data/` when `root` is
    given, and downloaded from the `lerobot/{dataset_id}` dataset repository at
    revision `version` otherwise.

    Example:
    ```python
    from_id = episode_data_index["from"][episode_id].item()
    to_id = episode_data_index["to"][episode_id].item()
    episode_frames = [dataset[i] for i in range(from_id, to_id)]
    ```
    """
    if root is None:
        safetensors_path = hf_hub_download(
            f"lerobot/{dataset_id}",
            "meta_data/episode_data_index.safetensors",
            repo_type="dataset",
            revision=version,
        )
    else:
        safetensors_path = Path(root) / dataset_id / "meta_data" / "episode_data_index.safetensors"

    return load_file(safetensors_path)
|
||||
|
||||
|
||||
def load_stats(dataset_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
    """Load stats, the statistics per modality computed over the full dataset (max, min, mean, std).

    The flat safetensors file is read from `root/dataset_id/meta_data/` when `root`
    is given, and downloaded from the `lerobot/{dataset_id}` dataset repository at
    revision `version` otherwise. Its keys are then expanded back into a nested
    dictionary with `unflatten_dict`.

    Example:
    ```python
    normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"]
    ```
    """
    if root is None:
        safetensors_path = hf_hub_download(
            f"lerobot/{dataset_id}", "meta_data/stats.safetensors", repo_type="dataset", revision=version
        )
    else:
        safetensors_path = Path(root) / dataset_id / "meta_data" / "stats.safetensors"

    flat_stats = load_file(safetensors_path)
    return unflatten_dict(flat_stats)
|
||||
|
||||
|
||||
def load_previous_and_future_frames(
|
||||
item: dict[str, torch.Tensor],
|
||||
hf_dataset: datasets.Dataset,
|
||||
episode_data_index: dict[str, torch.Tensor],
|
||||
delta_timestamps: dict[str, list[float]],
|
||||
tol: float,
|
||||
) -> dict[torch.Tensor]:
|
||||
@@ -31,6 +137,8 @@ def load_previous_and_future_frames(
|
||||
corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
|
||||
- hf_dataset (datasets.Dataset): A dictionary containing the full dataset. Each key corresponds to a different
|
||||
modality (e.g., "timestamp", "observation.image", "action").
|
||||
- episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices.
|
||||
They indicate the start index and end index of each episode in the dataset.
|
||||
- delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be
|
||||
retrieved. These deltas are added to the item timestamp to form the query timestamps.
|
||||
- tol (float, optional): The tolerance level used to determine if a data point is close enough to the query
|
||||
@@ -46,12 +154,14 @@ def load_previous_and_future_frames(
|
||||
issues with timestamps during data collection.
|
||||
"""
|
||||
# get indices of the frames associated to the episode, and their timestamps
|
||||
ep_data_id_from = item["episode_data_index_from"].item()
|
||||
ep_data_id_to = item["episode_data_index_to"].item()
|
||||
ep_id = item["episode_index"].item()
|
||||
ep_data_id_from = episode_data_index["from"][ep_id].item()
|
||||
ep_data_id_to = episode_data_index["to"][ep_id].item()
|
||||
ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to, 1)
|
||||
|
||||
# load timestamps
|
||||
ep_timestamps = hf_dataset.select_columns("timestamp")[ep_data_id_from:ep_data_id_to]["timestamp"]
|
||||
ep_timestamps = torch.stack(ep_timestamps)
|
||||
|
||||
# we make the assumption that the timestamps are sorted
|
||||
ep_first_ts = ep_timestamps[0]
|
||||
@@ -82,39 +192,57 @@ def load_previous_and_future_frames(
|
||||
|
||||
# load frames modality
|
||||
item[key] = hf_dataset.select_columns(key)[data_ids][key]
|
||||
item[key] = torch.stack(item[key])
|
||||
item[f"{key}_is_pad"] = is_pad
|
||||
|
||||
return item
|
||||
|
||||
|
||||
def get_stats_einops_patterns(dataset):
|
||||
"""These einops patterns will be used to aggregate batches and compute statistics."""
|
||||
stats_patterns = {
|
||||
"action": "b c -> c",
|
||||
"observation.state": "b c -> c",
|
||||
}
|
||||
for key in dataset.image_keys:
|
||||
stats_patterns[key] = "b c h w -> c 1 1"
|
||||
def get_stats_einops_patterns(hf_dataset):
    """Build the einops patterns used to aggregate batches and compute statistics.

    A single batch of two items is drawn from `hf_dataset` (no shuffling, no worker
    processes) and inspected to choose one reduction pattern per feature key.

    Note: We assume the images of `hf_dataset` are in channel first format

    Args:
        hf_dataset: Dataset exposing a `features` mapping and usable with a torch
            `DataLoader` (presumably a `datasets.Dataset` with a torch transform
            set -- TODO confirm against callers).

    Returns:
        A dict mapping each feature key to an einops pattern string.

    Raises:
        AssertionError: If a feature is float64, or an image feature is not a
            channel first float32 tensor with values in [0, 1].
        ValueError: If a non-image feature batch is neither 1- nor 2-dimensional.
    """

    dataloader = torch.utils.data.DataLoader(
        hf_dataset,
        num_workers=0,
        batch_size=2,
        shuffle=False,
    )
    batch = next(iter(dataloader))

    stats_patterns = {}
    for key, feats_type in hf_dataset.features.items():
        # sanity check that tensors are not float64
        assert batch[key].dtype != torch.float64

        if isinstance(feats_type, Image):
            # sanity check that images are channel first
            # (heuristic: the channel count is expected to be smaller than height and width)
            _, c, h, w = batch[key].shape
            assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"

            # sanity check that images are float32 in range [0,1]
            assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}"
            assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}"
            assert batch[key].min() >= 0, f"expect pixels greater than 0, but instead {batch[key].min()=}"

            # reduce over batch, height and width: one statistic per channel
            stats_patterns[key] = "b c h w -> c 1 1"
        elif batch[key].ndim == 2:
            stats_patterns[key] = "b c -> c "
        elif batch[key].ndim == 1:
            stats_patterns[key] = "b -> 1"
        else:
            raise ValueError(f"{key}, {feats_type}, {batch[key].shape}")

    return stats_patterns
|
||||
|
||||
|
||||
def compute_stats(dataset, batch_size=32, max_num_samples=None):
|
||||
def compute_stats(hf_dataset, batch_size=32, max_num_samples=None):
|
||||
if max_num_samples is None:
|
||||
max_num_samples = len(dataset)
|
||||
else:
|
||||
raise NotImplementedError("We need to set shuffle=True, but this violate an assert for now.")
|
||||
max_num_samples = len(hf_dataset)
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=4,
|
||||
batch_size=batch_size,
|
||||
shuffle=False,
|
||||
# pin_memory=cfg.device != "cpu",
|
||||
drop_last=False,
|
||||
)
|
||||
|
||||
# get einops patterns to aggregate batches and compute statistics
|
||||
stats_patterns = get_stats_einops_patterns(dataset)
|
||||
stats_patterns = get_stats_einops_patterns(hf_dataset)
|
||||
|
||||
# mean and std will be computed incrementally while max and min will track the running value.
|
||||
mean, std, max, min = {}, {}, {}, {}
|
||||
@@ -124,10 +252,24 @@ def compute_stats(dataset, batch_size=32, max_num_samples=None):
|
||||
max[key] = torch.tensor(-float("inf")).float()
|
||||
min[key] = torch.tensor(float("inf")).float()
|
||||
|
||||
def create_seeded_dataloader(hf_dataset, batch_size, seed):
|
||||
generator = torch.Generator()
|
||||
generator.manual_seed(seed)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
hf_dataset,
|
||||
num_workers=4,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
drop_last=False,
|
||||
generator=generator,
|
||||
)
|
||||
return dataloader
|
||||
|
||||
# Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
|
||||
# surprises when rerunning the sampler.
|
||||
first_batch = None
|
||||
running_item_count = 0 # for online mean computation
|
||||
dataloader = create_seeded_dataloader(hf_dataset, batch_size, seed=1337)
|
||||
for i, batch in enumerate(
|
||||
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
|
||||
):
|
||||
@@ -153,6 +295,7 @@ def compute_stats(dataset, batch_size=32, max_num_samples=None):
|
||||
|
||||
first_batch_ = None
|
||||
running_item_count = 0 # for online std computation
|
||||
dataloader = create_seeded_dataloader(hf_dataset, batch_size, seed=1337)
|
||||
for i, batch in enumerate(
|
||||
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
|
||||
):
|
||||
|
||||
@@ -1,25 +1,37 @@
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset, load_from_disk
|
||||
|
||||
from lerobot.common.datasets.utils import load_previous_and_future_frames
|
||||
from lerobot.common.datasets.utils import (
|
||||
load_episode_data_index,
|
||||
load_hf_dataset,
|
||||
load_previous_and_future_frames,
|
||||
load_stats,
|
||||
)
|
||||
|
||||
|
||||
class XarmDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
https://huggingface.co/datasets/lerobot/xarm_lift_medium
|
||||
https://huggingface.co/datasets/lerobot/xarm_lift_medium_replay
|
||||
https://huggingface.co/datasets/lerobot/xarm_push_medium
|
||||
https://huggingface.co/datasets/lerobot/xarm_push_medium_replay
|
||||
"""
|
||||
|
||||
# Copied from lerobot/__init__.py
|
||||
available_datasets = ["xarm_lift_medium"]
|
||||
available_datasets = [
|
||||
"xarm_lift_medium",
|
||||
"xarm_lift_medium_replay",
|
||||
"xarm_push_medium",
|
||||
"xarm_push_medium_replay",
|
||||
]
|
||||
fps = 15
|
||||
image_keys = ["observation.image"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id: str = "xarm_lift_medium",
|
||||
version: str | None = "v1.0",
|
||||
dataset_id: str,
|
||||
version: str | None = "v1.1",
|
||||
root: Path | None = None,
|
||||
split: str = "train",
|
||||
transform: callable = None,
|
||||
@@ -32,13 +44,10 @@ class XarmDataset(torch.utils.data.Dataset):
|
||||
self.split = split
|
||||
self.transform = transform
|
||||
self.delta_timestamps = delta_timestamps
|
||||
if self.root is not None:
|
||||
self.hf_dataset = load_from_disk(Path(self.root) / self.dataset_id / self.split)
|
||||
else:
|
||||
self.hf_dataset = load_dataset(
|
||||
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
|
||||
)
|
||||
self.hf_dataset = self.hf_dataset.with_format("torch")
|
||||
# load data from hub or locally when root is provided
|
||||
self.hf_dataset = load_hf_dataset(dataset_id, version, root, split)
|
||||
self.episode_data_index = load_episode_data_index(dataset_id, version, root)
|
||||
self.stats = load_stats(dataset_id, version, root)
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
@@ -46,7 +55,7 @@ class XarmDataset(torch.utils.data.Dataset):
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
return len(self.hf_dataset.unique("episode_id"))
|
||||
return len(self.hf_dataset.unique("episode_index"))
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
@@ -58,19 +67,11 @@ class XarmDataset(torch.utils.data.Dataset):
|
||||
item = load_previous_and_future_frames(
|
||||
item,
|
||||
self.hf_dataset,
|
||||
self.episode_data_index,
|
||||
self.delta_timestamps,
|
||||
tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
|
||||
)
|
||||
|
||||
# convert images from channel last (PIL) to channel first (pytorch)
|
||||
for key in self.image_keys:
|
||||
if item[key].ndim == 3:
|
||||
item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w
|
||||
elif item[key].ndim == 4:
|
||||
item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w
|
||||
else:
|
||||
raise ValueError(item[key].ndim)
|
||||
|
||||
if self.transform is not None:
|
||||
item = self.transform(item)
|
||||
|
||||
|
||||
@@ -39,4 +39,5 @@ def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv:
|
||||
for _ in range(num_parallel_envs)
|
||||
]
|
||||
)
|
||||
|
||||
return env
|
||||
|
||||
@@ -15,8 +15,19 @@ def preprocess_observation(observation, transform=None):
|
||||
|
||||
for imgkey, img in imgs.items():
|
||||
img = torch.from_numpy(img)
|
||||
# convert to (b c h w) torch format
|
||||
|
||||
# sanity check that images are channel last
|
||||
_, h, w, c = img.shape
|
||||
assert c < h and c < w, f"expect channel first images, but instead {img.shape}"
|
||||
|
||||
# sanity check that images are uint8
|
||||
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
||||
|
||||
# convert to channel first of type float32 in range [0,1]
|
||||
img = einops.rearrange(img, "b h w c -> b c h w")
|
||||
img = img.type(torch.float32)
|
||||
img /= 255
|
||||
|
||||
obs[imgkey] = img
|
||||
|
||||
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing requirement for "agent_pos"
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import torch
|
||||
from torchvision.transforms.v2 import Compose, Transform
|
||||
|
||||
|
||||
@@ -12,40 +11,6 @@ def apply_inverse_transform(item, transform):
|
||||
return item
|
||||
|
||||
|
||||
class Prod(Transform):
|
||||
invertible = True
|
||||
|
||||
def __init__(self, in_keys: list[str], prod: float):
|
||||
super().__init__()
|
||||
self.in_keys = in_keys
|
||||
self.prod = prod
|
||||
self.original_dtypes = {}
|
||||
|
||||
def forward(self, item):
|
||||
for key in self.in_keys:
|
||||
if key not in item:
|
||||
continue
|
||||
self.original_dtypes[key] = item[key].dtype
|
||||
item[key] = item[key].type(torch.float32) * self.prod
|
||||
return item
|
||||
|
||||
def inverse_transform(self, item):
|
||||
for key in self.in_keys:
|
||||
if key not in item:
|
||||
continue
|
||||
item[key] = (item[key] / self.prod).type(self.original_dtypes[key])
|
||||
return item
|
||||
|
||||
# def transform_observation_spec(self, obs_spec):
|
||||
# for key in self.in_keys:
|
||||
# if obs_spec.get(key, None) is None:
|
||||
# continue
|
||||
# obs_spec[key].space.high = obs_spec[key].space.high.type(torch.float32) * self.prod
|
||||
# obs_spec[key].space.low = obs_spec[key].space.low.type(torch.float32) * self.prod
|
||||
# obs_spec[key].dtype = torch.float32
|
||||
# return obs_spec
|
||||
|
||||
|
||||
class NormalizeTransform(Transform):
|
||||
invertible = True
|
||||
|
||||
|
||||
@@ -47,6 +47,7 @@ from PIL import Image as PILImage
|
||||
from tqdm import trange
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.utils import hf_transform_to_torch
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
|
||||
from lerobot.common.logger import log_output_dir
|
||||
@@ -208,11 +209,12 @@ def eval_policy(
|
||||
max_rewards.extend(batch_max_reward.tolist())
|
||||
all_successes.extend(batch_success.tolist())
|
||||
|
||||
# similar logic is implemented in dataset preprocessing
|
||||
# similar logic is implemented when datasets are pushed to hub (see: `push_to_hub`)
|
||||
ep_dicts = []
|
||||
episode_data_index = {"from": [], "to": []}
|
||||
num_episodes = dones.shape[0]
|
||||
total_frames = 0
|
||||
idx_from = 0
|
||||
id_from = 0
|
||||
for ep_id in range(num_episodes):
|
||||
num_frames = done_indices[ep_id].item() + 1
|
||||
total_frames += num_frames
|
||||
@@ -222,19 +224,20 @@ def eval_policy(
|
||||
if return_episode_data:
|
||||
ep_dict = {
|
||||
"action": actions[ep_id, :num_frames],
|
||||
"episode_id": torch.tensor([ep_id] * num_frames),
|
||||
"frame_id": torch.arange(0, num_frames, 1),
|
||||
"episode_index": torch.tensor([ep_id] * num_frames),
|
||||
"frame_index": torch.arange(0, num_frames, 1),
|
||||
"timestamp": torch.arange(0, num_frames, 1) / fps,
|
||||
"next.done": dones[ep_id, :num_frames],
|
||||
"next.reward": rewards[ep_id, :num_frames].type(torch.float32),
|
||||
"episode_data_index_from": torch.tensor([idx_from] * num_frames),
|
||||
"episode_data_index_to": torch.tensor([idx_from + num_frames] * num_frames),
|
||||
}
|
||||
for key in observations:
|
||||
ep_dict[key] = observations[key][ep_id][:num_frames]
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
idx_from += num_frames
|
||||
episode_data_index["from"].append(id_from)
|
||||
episode_data_index["to"].append(id_from + num_frames)
|
||||
|
||||
id_from += num_frames
|
||||
|
||||
# similar logic is implemented in dataset preprocessing
|
||||
if return_episode_data:
|
||||
@@ -247,14 +250,29 @@ def eval_policy(
|
||||
if key not in data_dict:
|
||||
data_dict[key] = []
|
||||
for ep_dict in ep_dicts:
|
||||
for x in ep_dict[key]:
|
||||
# c h w -> h w c
|
||||
img = PILImage.fromarray(x.permute(1, 2, 0).numpy())
|
||||
for img in ep_dict[key]:
|
||||
# sanity check that images are channel first
|
||||
c, h, w = img.shape
|
||||
assert c < h and c < w, f"expect channel first images, but instead {img.shape}"
|
||||
|
||||
# sanity check that images are float32 in range [0,1]
|
||||
assert img.dtype == torch.float32, f"expect torch.float32, but instead {img.dtype=}"
|
||||
assert img.max() <= 1, f"expect pixels lower than 1, but instead {img.max()=}"
|
||||
assert img.min() >= 0, f"expect pixels greater than 1, but instead {img.min()=}"
|
||||
|
||||
# from float32 in range [0,1] to uint8 in range [0,255]
|
||||
img *= 255
|
||||
img = img.type(torch.uint8)
|
||||
|
||||
# convert to channel last and numpy as expected by PIL
|
||||
img = PILImage.fromarray(img.permute(1, 2, 0).numpy())
|
||||
|
||||
data_dict[key].append(img)
|
||||
|
||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
|
||||
hf_dataset = Dataset.from_dict(data_dict).with_format("torch")
|
||||
hf_dataset = Dataset.from_dict(data_dict)
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
|
||||
if max_episodes_rendered > 0:
|
||||
batch_stacked_frames = np.stack(ep_frames, 1) # (b, t, *)
|
||||
@@ -307,7 +325,10 @@ def eval_policy(
|
||||
},
|
||||
}
|
||||
if return_episode_data:
|
||||
info["episodes"] = hf_dataset
|
||||
info["episodes"] = {
|
||||
"hf_dataset": hf_dataset,
|
||||
"episode_data_index": episode_data_index,
|
||||
}
|
||||
if max_episodes_rendered > 0:
|
||||
info["videos"] = videos
|
||||
return info
|
||||
|
||||
@@ -136,6 +136,7 @@ def add_episodes_inplace(
|
||||
concat_dataset: torch.utils.data.ConcatDataset,
|
||||
sampler: torch.utils.data.WeightedRandomSampler,
|
||||
hf_dataset: datasets.Dataset,
|
||||
episode_data_index: dict[str, torch.Tensor],
|
||||
pc_online_samples: float,
|
||||
):
|
||||
"""
|
||||
@@ -151,13 +152,15 @@ def add_episodes_inplace(
|
||||
- sampler (torch.utils.data.WeightedRandomSampler): A sampler that will be updated to
|
||||
reflect changes in the dataset sizes and specified sampling weights.
|
||||
- hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added.
|
||||
- episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices.
|
||||
They indicate the start index and end index of each episode in the dataset.
|
||||
- pc_online_samples (float): The target percentage of samples that should come from
|
||||
the online dataset during sampling operations.
|
||||
|
||||
Raises:
|
||||
- AssertionError: If the first episode_id or index in hf_dataset is not 0
|
||||
"""
|
||||
first_episode_id = hf_dataset.select_columns("episode_id")[0]["episode_id"].item()
|
||||
first_episode_id = hf_dataset.select_columns("episode_index")[0]["episode_index"].item()
|
||||
first_index = hf_dataset.select_columns("index")[0]["index"].item()
|
||||
assert first_episode_id == 0, f"We expect the first episode_id to be 0 and not {first_episode_id}"
|
||||
assert first_index == 0, f"We expect the first first_index to be 0 and not {first_index}"
|
||||
@@ -167,21 +170,22 @@ def add_episodes_inplace(
|
||||
online_dataset.hf_dataset = hf_dataset
|
||||
else:
|
||||
# find episode index and data frame indices according to previous episode in online_dataset
|
||||
start_episode = online_dataset.select_columns("episode_id")[-1]["episode_id"].item() + 1
|
||||
start_episode = online_dataset.select_columns("episode_index")[-1]["episode_index"].item() + 1
|
||||
start_index = online_dataset.select_columns("index")[-1]["index"].item() + 1
|
||||
|
||||
def shift_indices(example):
|
||||
# note: we dont shift "frame_id" since it represents the index of the frame in the episode it belongs to
|
||||
example["episode_id"] += start_episode
|
||||
# note: we dont shift "frame_index" since it represents the index of the frame in the episode it belongs to
|
||||
example["episode_index"] += start_episode
|
||||
example["index"] += start_index
|
||||
example["episode_data_index_from"] += start_index
|
||||
example["episode_data_index_to"] += start_index
|
||||
return example
|
||||
|
||||
disable_progress_bars() # map has a tqdm progress bar
|
||||
hf_dataset = hf_dataset.map(shift_indices)
|
||||
enable_progress_bars()
|
||||
|
||||
episode_data_index["from"] += start_index
|
||||
episode_data_index["to"] += start_index
|
||||
|
||||
# extend online dataset
|
||||
online_dataset.hf_dataset = concatenate_datasets([online_dataset.hf_dataset, hf_dataset])
|
||||
|
||||
@@ -334,9 +338,13 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||
seed=cfg.seed,
|
||||
)
|
||||
|
||||
online_pc_sampling = cfg.get("demo_schedule", 0.5)
|
||||
add_episodes_inplace(
|
||||
online_dataset, concat_dataset, sampler, eval_info["episodes"], online_pc_sampling
|
||||
online_dataset,
|
||||
concat_dataset,
|
||||
sampler,
|
||||
hf_dataset=eval_info["episodes"]["hf_dataset"],
|
||||
episode_data_index=eval_info["episodes"]["episode_data_index"],
|
||||
pc_online_samples=cfg.get("demo_schedule", 0.5),
|
||||
)
|
||||
|
||||
for _ in range(cfg.policy.utd):
|
||||
|
||||
@@ -22,11 +22,24 @@ def visualize_dataset_cli(cfg: dict):
|
||||
|
||||
|
||||
def cat_and_write_video(video_path, frames, fps):
|
||||
# Expects images in [0, 255].
|
||||
frames = torch.cat(frames)
|
||||
assert frames.dtype == torch.uint8
|
||||
frames = einops.rearrange(frames, "b c h w -> b h w c").numpy()
|
||||
imageio.mimsave(video_path, frames, fps=fps)
|
||||
|
||||
# Expects images in [0, 1].
|
||||
frame = frames[0]
|
||||
if frame.ndim == 4:
|
||||
raise NotImplementedError("We currently dont support multiple timestamps.")
|
||||
c, h, w = frame.shape
|
||||
assert c < h and c < w, f"expect channel first images, but instead {frame.shape}"
|
||||
|
||||
# sanity check that images are float32 in range [0,1]
|
||||
assert frame.dtype == torch.float32, f"expect torch.float32, but instead {frame.dtype=}"
|
||||
assert frame.max() <= 1, f"expect pixels lower than 1, but instead {frame.max()=}"
|
||||
assert frame.min() >= 0, f"expect pixels greater than 1, but instead {frame.min()=}"
|
||||
|
||||
# convert to channel last uint8 [0, 255]
|
||||
frames = einops.rearrange(frames, "b c h w -> b h w c")
|
||||
frames = (frames * 255).type(torch.uint8)
|
||||
imageio.mimsave(video_path, frames.numpy(), fps=fps)
|
||||
|
||||
|
||||
def visualize_dataset(cfg: dict, out_dir=None):
|
||||
@@ -44,9 +57,10 @@ def visualize_dataset(cfg: dict, out_dir=None):
|
||||
)
|
||||
|
||||
logging.info("Start rendering episodes from offline buffer")
|
||||
video_paths = render_dataset(dataset, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER, cfg.fps)
|
||||
video_paths = render_dataset(dataset, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER)
|
||||
for video_path in video_paths:
|
||||
logging.info(video_path)
|
||||
return video_paths
|
||||
|
||||
|
||||
def render_dataset(dataset, out_dir, max_num_episodes):
|
||||
@@ -77,7 +91,7 @@ def render_dataset(dataset, out_dir, max_num_episodes):
|
||||
# add current frame to list of frames to render
|
||||
frames[im_key].append(item[im_key])
|
||||
|
||||
end_of_episode = item["index"].item() == item["episode_data_index_to"].item() - 1
|
||||
end_of_episode = item["index"].item() == dataset.episode_data_index["to"][ep_id] - 1
|
||||
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
for im_key in dataset.image_keys:
|
||||
|
||||
Reference in New Issue
Block a user