From fccbce9ef9315d9c21517d4bb90ec7b3dc937147 Mon Sep 17 00:00:00 2001 From: fracapuano Date: Tue, 27 May 2025 11:43:55 +0200 Subject: [PATCH] add: randomized streaming dataset --- lerobot/common/datasets/lerobot_dataset.py | 4 + lerobot/common/datasets/streaming_dataset.py | 252 +++++++++++++++++++ lerobot/common/datasets/utils.py | 25 +- lerobot/common/datasets/video_utils.py | 71 +++++- 4 files changed, 339 insertions(+), 13 deletions(-) create mode 100644 lerobot/common/datasets/streaming_dataset.py diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index ab04e61e0..404fa40a0 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -151,6 +151,10 @@ class LeRobotDatasetMetadata: fpath = self.video_path.format(video_key=vid_key, chunk_index=chunk_idx, file_index=file_idx) return Path(fpath) + @property + def url_root(self) -> str: + return f"https://huggingface.co/datasets/{self.repo_id}/resolve/main" + @property def data_path(self) -> str: """Formattable string for the parquet files.""" diff --git a/lerobot/common/datasets/streaming_dataset.py b/lerobot/common/datasets/streaming_dataset.py new file mode 100644 index 000000000..d671fee21 --- /dev/null +++ b/lerobot/common/datasets/streaming_dataset.py @@ -0,0 +1,252 @@ +import random +from pathlib import Path +from typing import Callable, Dict, Generator, Iterator + +import datasets +import numpy as np +import torch +from datasets import load_dataset +from line_profiler import profile + +from lerobot.common.constants import HF_LEROBOT_HOME +from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata +from lerobot.common.datasets.utils import ( + check_version_compatibility, + item_to_torch, +) +from lerobot.common.datasets.video_utils import ( + VideoDecoderCache, + decode_video_frames_torchcodec, + get_safe_default_codec, +) + + +class 
StreamingLeRobotDataset(torch.utils.data.IterableDataset): + """LeRobotDataset with streaming capabilities. + + This class extends LeRobotDataset to add streaming functionality, allowing data to be streamed + rather than loaded entirely into memory. This is especially useful for large datasets that may + not fit in memory or when you want to quickly explore a dataset without downloading it completely. + + The key innovation is using a Backtrackable iterator that maintains a bounded buffer of recent + items, allowing us to access previous frames for delta timestamps without loading the entire + dataset into memory. + + Example: + Basic usage: + ```python + from lerobot.common.datasets.streaming_dataset import StreamingLeRobotDataset + + # Create a streaming dataset with delta timestamps + delta_timestamps = { + "observation.image": [-1.0, -0.5, 0.0], # 1 sec ago, 0.5 sec ago, current + "action": [0.0, 0.1, 0.2], # current, 0.1 sec future, 0.2 sec future + } + + dataset = StreamingLeRobotDataset( + repo_id="your-dataset-repo-id", + delta_timestamps=delta_timestamps, + streaming=True, + buffer_size=1000, + ) + + # Iterate over the dataset + for i, item in enumerate(dataset): + print(f"Sample {i}: Episode {item['episode_index']} Frame {item['frame_index']}") + # item will contain stacked frames according to delta_timestamps + if i >= 10: + break + ``` + """ + + def __init__( + self, + repo_id: str, + root: str | Path | None = None, + episodes: list[int] | None = None, + image_transforms: Callable | None = None, + tolerance_s: float = 1e-4, + revision: str | None = None, + force_cache_sync: bool = False, + video_backend: str | None = "torchcodec", + streaming: bool = True, + buffer_size: int = 1000, + max_num_shards: int = 16, + seed: int = 42, + rng: np.random.Generator | None = None, + ): + """Initialize a StreamingLeRobotDataset. + + Args: + repo_id (str): This is the repo id that will be used to fetch the dataset. 
+ root (Path | None, optional): Local directory to use for downloading/writing files. + episodes (list[int] | None, optional): If specified, this will only load episodes specified by + their episode_index in this list. + image_transforms (Callable | None, optional): Transform to apply to image data. + tolerance_s (float, optional): Tolerance in seconds for timestamp matching. + revision (str, optional): Git revision id (branch name, tag, or commit hash). + force_cache_sync (bool, optional): Flag to sync and refresh local files first. + video_backend (str | None, optional): Video backend to use for decoding videos. Uses "torchcodec" by default. + streaming (bool, optional): Whether to stream the dataset or load it all. Defaults to True. + buffer_size (int, optional): Buffer size for shuffling when streaming. Defaults to 1000. + max_num_shards (int, optional): Number of shards to re-shard the input dataset into. Defaults to 16. + seed (int, optional): Reproducibility random seed. + rng (np.random.Generator | None, optional): Random number generator. 
+ """ + super().__init__() + self.repo_id = repo_id + self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id + self.image_transforms = image_transforms + self.episodes = episodes + self.tolerance_s = tolerance_s + self.revision = revision if revision else CODEBASE_VERSION + self.video_backend = video_backend if video_backend else get_safe_default_codec() + self.seed = seed + self.rng = rng if rng is not None else np.random.default_rng(seed) + + self.streaming = streaming + self.buffer_size = buffer_size + + # We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown) + self.video_decoder_cache = VideoDecoderCache() + + # Unused attributes + self.image_writer = None + self.episode_buffer = None + + self.root.mkdir(exist_ok=True, parents=True) + + # Load metadata + self.meta = LeRobotDatasetMetadata( + self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync + ) + # Check version + check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION) + + self.hf_dataset = self.load_hf_dataset() + self.num_shards = min(self.hf_dataset.num_shards, max_num_shards) + + @property + def fps(self): + return self.meta.fps + + @staticmethod + def _iter_random_indices( + rng: np.random.Generator, buffer_size: int, random_batch_size=1000 + ) -> Iterator[int]: + while True: + yield from (int(i) for i in rng.integers(0, buffer_size, size=random_batch_size)) + + @staticmethod + def _infinite_generator_over_elements(elements: list[int]) -> Iterator[int]: + return (random.choice(list(elements)) for _ in iter(int, 1)) + + def load_hf_dataset(self) -> datasets.IterableDataset: + dataset = load_dataset(self.repo_id, split="train", streaming=self.streaming) + self.streaming_from_local = False + + # TODO(fracapuano): Add support for streaming from a local folder and not only from HF Hub + return dataset + + def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: + buffer_indices_generator = 
self._iter_random_indices(self.rng, self.buffer_size) + + # This buffer is populated while iterating on the dataset's shards + frames_buffer = [] + idx_to_iterable_dataset = { + idx: self._make_iterable_dataset(self.hf_dataset.shard(self.num_shards, index=idx)) + for idx in range(self.num_shards) + } + + try: + while available_shards := list(idx_to_iterable_dataset.keys()): + shard_key = next(self._infinite_generator_over_elements(available_shards)) + dataset = idx_to_iterable_dataset[shard_key] # selects which shard to iterate on + for frame in self.make_frame(dataset): + if len(frames_buffer) == self.buffer_size: + i = next(buffer_indices_generator) + yield frames_buffer[i] + frames_buffer[i] = frame + else: + frames_buffer.append(frame) + break # random shard sampled, switch shard + + except ( + RuntimeError, + StopIteration, + ): # NOTE: StopIteration inside a generator throws a RuntimeError since 3.7 + # Remove exhausted shard + del idx_to_iterable_dataset[shard_key] + + # Once shards are all exhausted, shuffle the buffer and yield the remaining frames + self.rng.shuffle(frames_buffer) + yield from frames_buffer + + def _make_iterable_dataset(self, dataset: datasets.Dataset) -> Iterator: + return iter(dataset) + + @profile + def make_frame(self, dataset_iterator: datasets.IterableDataset) -> Generator: + """Makes a frame starting from a dataset iterator""" + item = next(dataset_iterator) + item = item_to_torch(item) + + # Get episode index from the item + ep_idx = item["episode_index"] + + # Load video frames, when needed + if len(self.meta.video_keys) > 0: + current_ts = item["timestamp"] + query_timestamps = self._get_query_timestamps(current_ts, None) + video_frames = self._query_videos(query_timestamps, ep_idx) + item = {**video_frames, **item} + + # Add task as a string + task_idx = item["task_index"] + item["task"] = self.meta.tasks.iloc[task_idx].name + + yield item + + def _get_query_timestamps( + self, + current_ts: float, + query_indices: dict[str, 
list[int]] | None = None, + ) -> dict[str, list[float]]: + query_timestamps = {} + for key in self.meta.video_keys: + if query_indices is not None and key in query_indices: + timestamps = self.hf_dataset.select(query_indices[key])["timestamp"] + query_timestamps[key] = torch.stack(timestamps).tolist() + else: + query_timestamps[key] = [current_ts] + + return query_timestamps + + def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict: + """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function + in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a + Segmentation Fault. This probably happens because a memory reference to the video loader is created in + the main process and a subprocess fails to access it. + """ + item = {} + for vid_key, query_ts in query_timestamps.items(): + root = self.meta.url_root if self.streaming and not self.streaming_from_local else self.root + video_path = f"{root}/{self.meta.get_video_file_path(ep_idx, vid_key)}" + frames = decode_video_frames_torchcodec( + video_path, query_ts, self.tolerance_s, decoder_cache=self.video_decoder_cache + ) + item[vid_key] = frames.squeeze(0) + + return item + + +# Example usage +if __name__ == "__main__": + repo_id = "lerobot/aloha_mobile_cabinet" + dataset = StreamingLeRobotDataset(repo_id) + + for i, frame in enumerate(dataset): + print(frame) + + if i > 10: # only stream first 10 frames + break diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index bdf3eba97..96eeebc56 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -24,7 +24,7 @@ from collections.abc import Iterator from pathlib import Path from pprint import pformat from types import SimpleNamespace -from typing import Any +from typing import Any, TypeVar import datasets import numpy as np @@ -85,6 +85,8 @@ DEFAULT_FEATURES = { "task_index": {"dtype": "int64", "shape": 
def item_to_torch(item: dict) -> dict:
    """Convert tensor-like values of a dataset item to ``torch.Tensor``, in place.

    This function is used to convert an item from a streaming dataset to PyTorch
    tensors. Numpy arrays and lists are converted; the "task" key (natural-language
    text) and scalar/string values are left untouched.

    Args:
        item (dict): Dictionary of items from a dataset. Mutated in place.

    Returns:
        dict: The same dictionary, with all tensor-like values converted to torch.Tensor.
    """
    # Imports are local so utils.py does not pay for torch at import time
    # when this helper is unused; they are cached after the first call.
    import numpy as np
    import torch

    convertible = (np.ndarray, list)
    for key in item:
        value = item[key]
        # "task" holds a natural-language string; keep it as-is.
        if key == "task" or not isinstance(value, convertible):
            continue
        item[key] = torch.tensor(value)
    return item
class VideoDecoderCache:
    """Thread-safe cache for video decoders to avoid expensive re-initialization.

    Creating a torchcodec ``VideoDecoder`` is costly (container probing, and
    network round-trips for remote files), so one decoder per video path/URL is
    created lazily and reused. Entries keep their underlying fsspec file handle
    so it can be closed when the cache is cleared (the original implementation
    leaked these handles).
    """

    def __init__(self):
        # Maps video path/URL -> (decoder, file_handle).
        self._cache: Dict[str, Any] = {}
        self._lock = Lock()

    def get_decoder(self, video_path: str):
        """Return a cached decoder for ``video_path``, creating it on first use.

        Raises:
            ImportError: If torchcodec is not installed.
        """
        if importlib.util.find_spec("torchcodec"):
            from torchcodec.decoders import VideoDecoder
        else:
            raise ImportError("torchcodec is required but not available.")

        video_path = str(video_path)

        with self._lock:
            if video_path not in self._cache:
                # fsspec transparently handles both local paths and http(s) URLs;
                # trust_env honors proxy settings from environment variables.
                file_handle = fsspec.open(video_path, client_kwargs={"trust_env": True}).__enter__()
                decoder = VideoDecoder(file_handle, seek_mode="approximate")
                self._cache[video_path] = (decoder, file_handle)

            return self._cache[video_path][0]

    def clear(self):
        """Clear the cache, closing the underlying file handles."""
        with self._lock:
            for _, file_handle in self._cache.values():
                try:
                    file_handle.close()
                except Exception:
                    # Best-effort cleanup: one handle failing to close must not
                    # prevent the rest of the cache from being released.
                    pass
            self._cache.clear()

    def size(self) -> int:
        """Return the number of cached decoders."""
        with self._lock:
            return len(self._cache)


# Global instance
_default_decoder_cache = VideoDecoderCache()
The number of key frames in a video can be adjusted during encoding to take into account decoding time and video size in bytes. """ + if decoder_cache is None: + decoder_cache = _default_decoder_cache - if importlib.util.find_spec("torchcodec"): - from torchcodec.decoders import VideoDecoder - else: - raise ImportError("torchcodec is required but not available.") + # Use cached decoder instead of creating new one each time + decoder = decoder_cache.get_decoder(str(video_path)) - # initialize video decoder - decoder = VideoDecoder(video_path, device=device, seek_mode="approximate") - loaded_frames = [] loaded_ts = [] + loaded_frames = [] + # get metadata for frame information metadata = decoder.metadata average_fps = metadata.average_fps - # convert timestamps to frame indices frame_indices = [round(ts * average_fps) for ts in timestamps] - # retrieve frames based on indices frames_batch = decoder.get_frames_at(indices=frame_indices)