import logging
from pathlib import Path
from typing import Callable

import einops
import torch
import torchrl
import tqdm
from huggingface_hub import snapshot_download
from tensordict import TensorDict
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import Sampler
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
from torchrl.envs.transforms.transforms import Compose

HF_USER = "lerobot"

class AbstractDataset(TensorDictReplayBuffer):
    """
    AbstractDataset represents a dataset in the context of imitation learning or reinforcement learning.
    This class is designed to be subclassed by concrete implementations that specify particular types of datasets.
    These implementations can vary based on the source of the data, the environment the data pertains to,
    or the specific kind of data manipulation applied.

    Note:
        - `TensorDictReplayBuffer` is the base class from which `AbstractDataset` inherits. It provides the
          foundational functionality for storing and retrieving `TensorDict`-like data.
        - `available_datasets` should be overridden by concrete subclasses to list the specific dataset variants
          supported. These variants are expected to correspond to HuggingFace datasets on the hub.
          For instance, `AlohaDataset`, which inherits from `AbstractDataset`, has 4 available dataset variants:
            - [aloha_sim_transfer_cube_scripted](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted)
            - [aloha_sim_insertion_scripted](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted)
            - [aloha_sim_transfer_cube_human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human)
            - [aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human)
        - When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to
          (a minimal sketch follows this list):
            1. set the required class attributes:
                - for classes inheriting from `AbstractDataset`: `available_datasets`
                - for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
                - for classes inheriting from `AbstractPolicy`: `name`
            2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`,
               `available_policies`)
            3. update variables in `tests/test_available.py` by importing your new class
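
    Example:
        A minimal sketch of a concrete subclass. The `MyDataset` class and the
        "my_sim_task_scripted" dataset id below are hypothetical, for illustration only:

            class MyDataset(AbstractDataset):
                available_datasets = ["my_sim_task_scripted"]

            dataset = MyDataset("my_sim_task_scripted", batch_size=32)
            batch = dataset.sample()  # TensorDict with 32 transitions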
    """

    available_datasets: list[str] | None = None

    def __init__(
        self,
        dataset_id: str,
        version: str | None = None,
        batch_size: int | None = None,
        *,
        shuffle: bool = True,
        root: Path | None = None,
        pin_memory: bool = False,
        prefetch: int | None = None,
        sampler: Sampler | None = None,
        collate_fn: Callable | None = None,
        writer: Writer | None = None,
        transform: "torchrl.envs.Transform" = None,
    ):
        assert (
            self.available_datasets is not None
        ), "Subclasses of `AbstractDataset` should set the `available_datasets` class attribute."
        assert (
            dataset_id in self.available_datasets
        ), f"The provided dataset ({dataset_id}) is not on the list of available datasets {self.available_datasets}."

        self.dataset_id = dataset_id
        self.version = version
        self.shuffle = shuffle
        self.root = root

        if self.root is not None and self.version is not None:
            logging.warning(
                f"The version of the dataset ({self.version}) is not enforced when root is provided ({self.root})."
            )

        storage = self._download_or_load_dataset()

        super().__init__(
            storage=storage,
            sampler=sampler,
            writer=ImmutableDatasetWriter() if writer is None else writer,
            collate_fn=_collate_id if collate_fn is None else collate_fn,
            pin_memory=pin_memory,
            prefetch=prefetch,
            batch_size=batch_size,
            transform=transform,
        )

    @property
    def stats_patterns(self) -> dict:
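        # Maps each data key to the `einops.reduce` pattern used when aggregating statistics:
        # states and actions keep one statistic per dimension ("b c -> c"), while images keep
        # one statistic per channel, left broadcastable over spatial dims ("b c h w -> c 1 1").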
        return {
            ("observation", "state"): "b c -> c",
            ("observation", "image"): "b c h w -> c 1 1",
            ("action",): "b c -> c",
        }

    @property
    def image_keys(self) -> list:
        return [("observation", "image")]

    @property
    def num_cameras(self) -> int:
        return len(self.image_keys)

    @property
    def num_samples(self) -> int:
        return len(self)

    @property
    def num_episodes(self) -> int:
        return len(self._storage._storage["episode"].unique())

    @property
    def transform(self):
        return self._transform

    def set_transform(self, transform):
        if not isinstance(transform, Compose):
            # required since torchrl calls `len(self._transform)` downstream
            if isinstance(transform, list):
                self._transform = Compose(*transform)
            else:
                self._transform = Compose(transform)
        else:
            self._transform = transform

    def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict:
        stats_path = self.data_dir / "stats.pth"
        if stats_path.exists():
            stats = torch.load(stats_path)
        else:
            logging.info(f"Computing stats and saving them to {stats_path}")
            stats = self._compute_stats(num_batch, batch_size)
            torch.save(stats, stats_path)
        return stats

    def _download_or_load_dataset(self) -> TensorStorage:
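        # Without an explicit `root`, fetch (or reuse the locally cached copy of) the
        # `{HF_USER}/{dataset_id}` snapshot from the HuggingFace hub, pinned to `self.version`;
        # otherwise read from `root/dataset_id` and skip version enforcement.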
        if self.root is None:
            self.data_dir = Path(
                snapshot_download(
                    repo_id=f"{HF_USER}/{self.dataset_id}", repo_type="dataset", revision=self.version
                )
            )
        else:
            self.data_dir = self.root / self.dataset_id
        return TensorStorage(TensorDict.load_memmap(self.data_dir / "replay_buffer"))

    def _compute_stats(self, num_batch=100, batch_size=32):
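        # Two passes over `num_batch` sampled batches: the first accumulates per-batch means
        # and running min/max, the second accumulates squared deviations of per-batch means
        # from the global mean. Note that the resulting "std" is therefore an approximation
        # computed over batch means rather than over individual samples.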
        rb = TensorDictReplayBuffer(
            storage=self._storage,
            batch_size=batch_size,
            prefetch=True,
        )

        mean, std, max, min = {}, {}, {}, {}

        # first pass: compute mean, min, max
        for _ in tqdm.tqdm(range(num_batch)):
            batch = rb.sample()
            for key, pattern in self.stats_patterns.items():
                batch[key] = batch[key].float()
                if key not in mean:
                    # the first batch initializes mean, min, max
                    mean[key] = einops.reduce(batch[key], pattern, "mean")
                    max[key] = einops.reduce(batch[key], pattern, "max")
                    min[key] = einops.reduce(batch[key], pattern, "min")
                else:
                    mean[key] += einops.reduce(batch[key], pattern, "mean")
                    max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
                    min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))

        for key in self.stats_patterns:
            mean[key] /= num_batch

        # second pass: compute std (and keep updating min, max)
        for _ in tqdm.tqdm(range(num_batch)):
            batch = rb.sample()
            for key, pattern in self.stats_patterns.items():
                batch[key] = batch[key].float()
                batch_mean = einops.reduce(batch[key], pattern, "mean")
                if key not in std:
                    # the first batch initializes std
                    std[key] = (batch_mean - mean[key]) ** 2
                else:
                    std[key] += (batch_mean - mean[key]) ** 2
                max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
                min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))

        for key in self.stats_patterns:
            std[key] = torch.sqrt(std[key] / num_batch)

        stats = TensorDict({}, batch_size=[])
        for key in self.stats_patterns:
            stats[(*key, "mean")] = mean[key]
            stats[(*key, "std")] = std[key]
            stats[(*key, "max")] = max[key]
            stats[(*key, "min")] = min[key]

            if key[0] == "observation":
                # use the same stats for the next observations
                stats[("next", *key)] = stats[key]
        return stats