import logging
from copy import deepcopy
from math import ceil
from pathlib import Path
from typing import Callable

import einops
import torch
import torchrl
import tqdm
from huggingface_hub import snapshot_download
from tensordict import TensorDict
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import Sampler, SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
from torchrl.envs.transforms.transforms import Compose

HF_USER = "lerobot"


class AbstractDataset(TensorDictReplayBuffer):
    """
    AbstractDataset represents a dataset in the context of imitation learning or reinforcement learning.
    This class is designed to be subclassed by concrete implementations that specify particular types of datasets.
    These implementations can vary based on the source of the data, the environment the data pertains to,
    or the specific kind of data manipulation applied.

    Note:
        - `TensorDictReplayBuffer` is the base class from which `AbstractDataset` inherits. It provides the
          foundational functionality for storing and retrieving `TensorDict`-like data.
        - `available_datasets` should be overridden by concrete subclasses to list the specific dataset variants
          supported. It is expected that these variants correspond to a HuggingFace dataset on the hub.
          For instance, the `AlohaDataset` which inherits from `AbstractDataset` has 4 available dataset variants:
            - [aloha_sim_transfer_cube_scripted](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted)
            - [aloha_sim_insertion_scripted](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted)
            - [aloha_sim_transfer_cube_human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human)
            - [aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human)
        - When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
            1. set the required class attributes:
                - for classes inheriting from `AbstractDataset`: `available_datasets`
                - for classes inheriting from `AbstractEnv`: `name`, `available_tasks`
                - for classes inheriting from `AbstractPolicy`: `name`
            2. update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`,
               `available_policies`)
            3. update variables in `tests/test_available.py` by importing your new class
    """

    # Names of the dataset variants a concrete subclass supports; must be set by every subclass.
    available_datasets: list[str] | None = None

    def __init__(
        self,
        dataset_id: str,
        version: str | None = None,
        batch_size: int | None = None,
        *,
        shuffle: bool = True,
        root: Path | None = None,
        pin_memory: bool = False,
        prefetch: int | None = None,
        sampler: Sampler | None = None,
        collate_fn: Callable | None = None,
        writer: Writer | None = None,
        transform: "torchrl.envs.Transform | None" = None,
    ):
        """Locate (download or read locally) the dataset and initialise the underlying replay buffer.

        Args:
            dataset_id: One of `available_datasets`. Used as the hub repo `f"{HF_USER}/{dataset_id}"` when
                `root` is None, otherwise as a sub-directory of `root`.
            version: Hub revision passed to `snapshot_download`. It is not enforced (only a warning is logged)
                when `root` is provided.
            batch_size: Forwarded to `TensorDictReplayBuffer`.
            shuffle: Stored as `self.shuffle`; it is not consumed by this class itself (presumably read by
                subclasses or callers — verify).
            root: Optional local directory containing `<dataset_id>/replay_buffer`. Converted to `Path`.
            pin_memory: Forwarded to `TensorDictReplayBuffer`.
            prefetch: Forwarded to `TensorDictReplayBuffer`.
            sampler: Forwarded to `TensorDictReplayBuffer` as-is (None lets the base class pick its default).
            collate_fn: Forwarded to `TensorDictReplayBuffer`; defaults to torchrl's `_collate_id`.
            writer: Forwarded to `TensorDictReplayBuffer`; defaults to `ImmutableDatasetWriter()`.
            transform: Forwarded to `TensorDictReplayBuffer`.

        Raises:
            AssertionError: If the subclass did not set `available_datasets` or `dataset_id` is not listed in it.
        """
        assert (
            self.available_datasets is not None
        ), "Subclasses of `AbstractDataset` should set the `available_datasets` class attribute."
        assert (
            dataset_id in self.available_datasets
        ), f"The provided dataset ({dataset_id}) is not on the list of available datasets {self.available_datasets}."

        self.dataset_id = dataset_id
        self.version = version
        self.shuffle = shuffle
        self.root = root if root is None else Path(root)

        if self.root is not None and self.version is not None:
            logging.warning(
                f"The version of the dataset ({self.version}) is not enforced when root is provided ({self.root})."
            )

        # Also sets `self.data_dir`, which `compute_or_load_stats` relies on.
        storage = self._download_or_load_dataset()

        super().__init__(
            storage=storage,
            sampler=sampler,
            writer=ImmutableDatasetWriter() if writer is None else writer,
            collate_fn=_collate_id if collate_fn is None else collate_fn,
            pin_memory=pin_memory,
            prefetch=prefetch,
            batch_size=batch_size,
            transform=transform,
        )

    @property
    def stats_patterns(self) -> dict[tuple[str, ...], str]:
        """Map each data key to the einops pattern that reduces a batch to its per-channel statistic.

        Image statistics are reduced to `c 1 1` (over batch, height and width) so they broadcast against
        `c h w` images.
        """
        return {
            ("observation", "state"): "b c -> c",
            ("observation", "image"): "b c h w -> c 1 1",
            ("action",): "b c -> c",
        }

    @property
    def image_keys(self) -> list:
        """Keys holding camera images."""
        return [("observation", "image")]

    @property
    def num_cameras(self) -> int:
        """Number of image keys (one per camera)."""
        return len(self.image_keys)

    @property
    def num_samples(self) -> int:
        """Number of samples in the buffer (`len(self)`)."""
        return len(self)

    @property
    def num_episodes(self) -> int:
        """Number of distinct values in the storage's `episode` entry."""
        return len(self._storage._storage["episode"].unique())

    @property
    def transform(self):
        """The transform currently applied by the buffer."""
        return self._transform

    def set_transform(self, transform) -> None:
        """Set the buffer's transform, wrapping it in a `Compose` when it is not one already.

        A list is unpacked into `Compose(*transform)`; any other non-`Compose` value becomes `Compose(transform)`.
        """
        if not isinstance(transform, Compose):
            # required since torchrl calls `len(self._transform)` downstream
            if isinstance(transform, list):
                self._transform = Compose(*transform)
            else:
                self._transform = Compose(transform)
        else:
            self._transform = transform

    def compute_or_load_stats(self, batch_size: int = 32) -> TensorDict:
        """Return dataset statistics, loading them from `data_dir/stats.pth` or computing and saving them there.

        Args:
            batch_size: Batch size used by `_compute_stats` when the cache file does not exist.
        """
        stats_path = self.data_dir / "stats.pth"
        if stats_path.exists():
            # NOTE(review): `torch.load` unpickles the file and `data_dir` may be a HuggingFace hub snapshot;
            # only load stats from repositories you trust.
            stats = torch.load(stats_path)
        else:
            logging.info(f"compute_stats and save to {stats_path}")
            stats = self._compute_stats(batch_size)
            torch.save(stats, stats_path)
        return stats

    def _download_or_load_dataset(self) -> TensorStorage:
        """Resolve `self.data_dir` (hub snapshot or local `root/dataset_id`) and memory-map its replay buffer."""
        if self.root is None:
            self.data_dir = Path(
                snapshot_download(
                    repo_id=f"{HF_USER}/{self.dataset_id}", repo_type="dataset", revision=self.version
                )
            )
        else:
            self.data_dir = self.root / self.dataset_id
        return TensorStorage(TensorDict.load_memmap(self.data_dir / "replay_buffer"))

    def _compute_stats(self, batch_size: int = 32):
        """Compute dataset statistics including minimum, maximum, mean, and standard deviation.

        Two sequential passes are made over the whole storage. The first accumulates mean, max and min. The
        second accumulates the population standard deviation around that mean.

        Args:
            batch_size: Number of samples drawn per step. It sizes both the temporary replay buffer and the
                number of steps, so each pass covers every sample exactly once.

        Returns:
            A `TensorDict` with `mean`, `std`, `max` and `min` under each key of `stats_patterns`. Stats of
            `observation` keys are also exposed under the matching `("next", ...)` key.
        """
        rb = TensorDictReplayBuffer(
            storage=self._storage,
            # Must match the step count below; a mismatch would skip or double-count samples.
            batch_size=batch_size,
            prefetch=True,
            # Note: Due to be refactored soon. The point is that we should go through the whole dataset.
            sampler=SamplerWithoutReplacement(drop_last=False, shuffle=False),
        )
        num_samples = len(rb)
        num_batches = ceil(num_samples / batch_size)

        # mean and std will be computed incrementally while max and min will track the running value.
        mean, std, max_, min_ = {}, {}, {}, {}
        for key in self.stats_patterns:
            mean[key] = torch.tensor(0.0).float()
            std[key] = torch.tensor(0.0).float()
            max_[key] = torch.tensor(-float("inf")).float()
            min_[key] = torch.tensor(float("inf")).float()

        # compute mean, min, max
        # Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
        # surprises when rerunning the sampler.
        first_batch = None
        for _ in tqdm.tqdm(range(num_batches)):
            batch = rb.sample()
            if first_batch is None:
                first_batch = deepcopy(batch)
            for key, pattern in self.stats_patterns.items():
                batch[key] = batch[key].float()
                # Weight each batch mean by its size (the last batch may be smaller), then divide by the total.
                mean[key] = mean[key] + einops.reduce(batch[key], pattern, "mean") * batch.batch_size[0]
                max_[key] = torch.maximum(max_[key], einops.reduce(batch[key], pattern, "max"))
                min_[key] = torch.minimum(min_[key], einops.reduce(batch[key], pattern, "min"))

        for key in self.stats_patterns:
            mean[key] = mean[key] / num_samples

        # compute std
        first_batch_ = None
        for _ in tqdm.tqdm(range(num_batches)):
            batch = rb.sample()
            # Sanity check to make sure the batches are still in the same order as before.
            if first_batch_ is None:
                first_batch_ = deepcopy(batch)
                for key in self.stats_patterns:
                    assert torch.equal(first_batch_[key], first_batch[key])
            for key, pattern in self.stats_patterns.items():
                batch[key] = batch[key].float()
                # Sum over batch then divide by total number of samples.
                std[key] = (
                    std[key] + einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean") * batch.batch_size[0]
                )

        for key in self.stats_patterns:
            std[key] = torch.sqrt(std[key] / num_samples)

        stats = TensorDict({}, batch_size=[])
        for key in self.stats_patterns:
            stats[(*key, "mean")] = mean[key]
            stats[(*key, "std")] = std[key]
            stats[(*key, "max")] = max_[key]
            stats[(*key, "min")] = min_[key]

            if key[0] == "observation":
                # use same stats for the next observations
                stats[("next", *key)] = stats[key]
        return stats