add: randomized streaming dataset

This commit is contained in:
fracapuano
2025-05-27 11:43:55 +02:00
parent f07887e8d1
commit fccbce9ef9
4 changed files with 339 additions and 13 deletions

View File

@@ -151,6 +151,10 @@ class LeRobotDatasetMetadata:
fpath = self.video_path.format(video_key=vid_key, chunk_index=chunk_idx, file_index=file_idx)
return Path(fpath)
@property
def url_root(self) -> str:
    """Base Hugging Face Hub URL under which this dataset's files are resolved."""
    return "https://huggingface.co/datasets/" + self.repo_id + "/resolve/main"
@property
def data_path(self) -> str:
"""Formattable string for the parquet files."""

View File

@@ -0,0 +1,252 @@
import random
from pathlib import Path
from typing import Callable, Dict, Generator, Iterator
import datasets
import numpy as np
import torch
from datasets import load_dataset
from line_profiler import profile
from lerobot.common.constants import HF_LEROBOT_HOME
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
from lerobot.common.datasets.utils import (
check_version_compatibility,
item_to_torch,
)
from lerobot.common.datasets.video_utils import (
VideoDecoderCache,
decode_video_frames_torchcodec,
get_safe_default_codec,
)
class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
    """LeRobotDataset with streaming capabilities.

    This class extends LeRobotDataset to add streaming functionality, allowing data to be streamed
    rather than loaded entirely into memory. This is especially useful for large datasets that may
    not fit in memory or when you want to quickly explore a dataset without downloading it completely.

    Frames are served from a bounded shuffle buffer of ``buffer_size`` items that is refilled from
    randomly-sampled dataset shards, which provides approximate shuffling without materializing the
    whole dataset.

    Example:
        Basic usage:
        ```python
        from lerobot.common.datasets.streaming_dataset import StreamingLeRobotDataset

        dataset = StreamingLeRobotDataset(
            repo_id="your-dataset-repo-id",
            streaming=True,
            buffer_size=1000,
        )

        # Iterate over the dataset
        for i, item in enumerate(dataset):
            print(f"Sample {i}: Episode {item['episode_index']} Frame {item['frame_index']}")
            if i >= 10:
                break
        ```
    """

    def __init__(
        self,
        repo_id: str,
        root: str | Path | None = None,
        episodes: list[int] | None = None,
        image_transforms: Callable | None = None,
        tolerance_s: float = 1e-4,
        revision: str | None = None,
        force_cache_sync: bool = False,
        video_backend: str | None = "torchcodec",
        streaming: bool = True,
        buffer_size: int = 1000,
        max_num_shards: int = 16,
        seed: int = 42,
        rng: np.random.Generator | None = None,
    ):
        """Initialize a StreamingLeRobotDataset.

        Args:
            repo_id (str): This is the repo id that will be used to fetch the dataset.
            root (Path | None, optional): Local directory to use for downloading/writing files.
            episodes (list[int] | None, optional): If specified, this will only load episodes specified by
                their episode_index in this list.
            image_transforms (Callable | None, optional): Transform to apply to image data.
            tolerance_s (float, optional): Tolerance in seconds for timestamp matching.
            revision (str, optional): Git revision id (branch name, tag, or commit hash).
            force_cache_sync (bool, optional): Flag to sync and refresh local files first.
            video_backend (str | None, optional): Video backend to use for decoding videos. Uses "torchcodec" by default.
            streaming (bool, optional): Whether to stream the dataset or load it all. Defaults to True.
            buffer_size (int, optional): Buffer size for shuffling when streaming. Defaults to 1000.
            max_num_shards (int, optional): Number of shards to re-shard the input dataset into. Defaults to 16.
            seed (int, optional): Reproducibility random seed.
            rng (np.random.Generator | None, optional): Random number generator.
        """
        super().__init__()
        self.repo_id = repo_id
        self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
        self.image_transforms = image_transforms
        self.episodes = episodes
        self.tolerance_s = tolerance_s
        self.revision = revision if revision else CODEBASE_VERSION
        self.video_backend = video_backend if video_backend else get_safe_default_codec()
        self.seed = seed
        self.rng = rng if rng is not None else np.random.default_rng(seed)
        self.streaming = streaming
        self.buffer_size = buffer_size

        # We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown)
        self.video_decoder_cache = VideoDecoderCache()

        # Unused attributes, kept for interface parity with LeRobotDataset.
        self.image_writer = None
        self.episode_buffer = None

        self.root.mkdir(exist_ok=True, parents=True)

        # Load metadata
        self.meta = LeRobotDatasetMetadata(
            self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
        )

        # Check version
        check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)

        self.hf_dataset = self.load_hf_dataset()
        # Never ask for more shards than the underlying dataset can provide.
        self.num_shards = min(self.hf_dataset.num_shards, max_num_shards)

    @property
    def fps(self):
        """Frames per second, as recorded in the dataset metadata."""
        return self.meta.fps

    @staticmethod
    def _iter_random_indices(
        rng: np.random.Generator, buffer_size: int, random_batch_size=1000
    ) -> Iterator[int]:
        """Yield an endless stream of random buffer slots, drawn in batches for speed."""
        while True:
            yield from (int(i) for i in rng.integers(0, buffer_size, size=random_batch_size))

    @staticmethod
    def _infinite_generator_over_elements(elements: list[int]) -> Iterator[int]:
        """Yield elements sampled uniformly at random (with replacement), forever.

        NOTE(review): uses the global `random` module, so its output is NOT controlled by the
        dataset's seed; `__iter__` therefore samples shards with `self.rng` instead.
        """
        return (random.choice(list(elements)) for _ in iter(int, 1))

    def load_hf_dataset(self) -> datasets.IterableDataset:
        """Load the Hub dataset, in streaming mode when `self.streaming` is set."""
        dataset = load_dataset(self.repo_id, split="train", streaming=self.streaming)
        self.streaming_from_local = False
        # TODO(fracapuano): Add support for streaming from a local folder and not only from HF Hub
        return dataset

    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        """Iterate over frames in approximately-shuffled order.

        At each step a random shard is sampled, one frame is read from it, and that frame
        swaps places with a randomly-chosen buffered frame, which is yielded. Exhausted
        shards are removed; once every shard is drained, the remaining buffered frames are
        shuffled and yielded.
        """
        buffer_indices_generator = self._iter_random_indices(self.rng, self.buffer_size)

        # This buffer is populated while iterating on the dataset's shards
        frames_buffer = []

        idx_to_iterable_dataset = {
            idx: self._make_iterable_dataset(self.hf_dataset.shard(self.num_shards, index=idx))
            for idx in range(self.num_shards)
        }

        while available_shards := list(idx_to_iterable_dataset.keys()):
            # Sample the shard with self.rng (not the global `random` module) so the
            # iteration order is reproducible for a given seed.
            shard_key = available_shards[int(self.rng.integers(len(available_shards)))]
            dataset = idx_to_iterable_dataset[shard_key]  # selects which shard to iterate on
            try:
                for frame in self.make_frame(dataset):
                    if len(frames_buffer) == self.buffer_size:
                        i = next(buffer_indices_generator)
                        yield frames_buffer[i]
                        frames_buffer[i] = frame
                    else:
                        frames_buffer.append(frame)
                    break  # one frame per step: switch to a newly-sampled shard
            except (RuntimeError, StopIteration):
                # NOTE: StopIteration raised inside the make_frame generator surfaces as a
                # RuntimeError since Python 3.7 (PEP 479). Either way, this shard is
                # exhausted: drop it and keep draining the remaining shards. (The try/except
                # must live INSIDE the loop — wrapping the whole loop aborted iteration at
                # the first exhausted shard, silently skipping the other shards' data.)
                del idx_to_iterable_dataset[shard_key]

        # Once shards are all exhausted, shuffle the buffer and yield the remaining frames
        self.rng.shuffle(frames_buffer)
        yield from frames_buffer

    def _make_iterable_dataset(self, dataset: datasets.Dataset) -> Iterator:
        """Turn a (sharded) dataset into a plain iterator."""
        return iter(dataset)

    @profile
    def make_frame(self, dataset_iterator: datasets.IterableDataset) -> Generator:
        """Makes a frame starting from a dataset iterator.

        Generator yielding exactly one fully-assembled frame per call: tensors converted,
        video frames decoded (when the dataset has video keys) and the task string attached.
        """
        item = next(dataset_iterator)
        item = item_to_torch(item)

        # Get episode index from the item
        ep_idx = item["episode_index"]

        # Load video frames, when needed
        if len(self.meta.video_keys) > 0:
            current_ts = item["timestamp"]
            query_timestamps = self._get_query_timestamps(current_ts, None)
            video_frames = self._query_videos(query_timestamps, ep_idx)
            item = {**video_frames, **item}

        # Add task as a string
        task_idx = item["task_index"]
        item["task"] = self.meta.tasks.iloc[task_idx].name

        yield item

    def _get_query_timestamps(
        self,
        current_ts: float,
        query_indices: dict[str, list[int]] | None = None,
    ) -> dict[str, list[float]]:
        """Map each video key to the timestamps that must be decoded for the current frame."""
        query_timestamps = {}
        for key in self.meta.video_keys:
            if query_indices is not None and key in query_indices:
                # NOTE(review): `select` is not available on streaming (Iterable) datasets;
                # this branch is only reachable when query_indices is passed — confirm callers.
                timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
                query_timestamps[key] = torch.stack(timestamps).tolist()
            else:
                query_timestamps[key] = [current_ts]
        return query_timestamps

    def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict:
        """Decode the requested timestamps for each video key of episode `ep_idx`.

        Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
        in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a
        Segmentation Fault. This probably happens because a memory reference to the video loader is created in
        the main process and a subprocess fails to access it.
        """
        item = {}
        for vid_key, query_ts in query_timestamps.items():
            # Stream straight from the Hub unless the data lives locally.
            root = self.meta.url_root if self.streaming and not self.streaming_from_local else self.root
            video_path = f"{root}/{self.meta.get_video_file_path(ep_idx, vid_key)}"
            frames = decode_video_frames_torchcodec(
                video_path, query_ts, self.tolerance_s, decoder_cache=self.video_decoder_cache
            )
            item[vid_key] = frames.squeeze(0)
        return item
# Example usage
if __name__ == "__main__":
    repo_id = "lerobot/aloha_mobile_cabinet"
    dataset = StreamingLeRobotDataset(repo_id)
    # Smoke test: stream only the first few frames (same loop as the class docstring example).
    for i, frame in enumerate(dataset):
        print(frame)
        if i >= 10:
            break

View File

@@ -24,7 +24,7 @@ from collections.abc import Iterator
from pathlib import Path
from pprint import pformat
from types import SimpleNamespace
from typing import Any
from typing import Any, TypeVar
import datasets
import numpy as np
@@ -85,6 +85,8 @@ DEFAULT_FEATURES = {
"task_index": {"dtype": "int64", "shape": (1,), "names": None},
}
# Generic type variable for annotating generic helpers in this module.
# NOTE(review): no use of T is visible in this hunk — confirm it is referenced elsewhere.
T = TypeVar("T")
def get_parquet_file_size_in_mb(parquet_path):
metadata = pq.read_metadata(parquet_path)
@@ -898,3 +900,24 @@ def safe_write_dataframe_to_parquet(df: pandas.DataFrame, path: Path, image_keys
datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
else:
df.to_parquet(path)
def item_to_torch(item: dict) -> dict:
    """Convert array-like values of a dataset item to PyTorch tensors, in place.

    This function is used to normalize an item coming from a streaming dataset:
    numpy arrays and (numeric) lists become ``torch.Tensor``. The "task" entry is
    left untouched because it holds text, which ``torch.tensor`` cannot ingest.
    Scalars and other value types pass through unchanged.

    Args:
        item (dict): Dictionary of items from a dataset. Mutated in place.

    Returns:
        dict: The same dictionary, with all tensor-like items converted to torch.Tensor.
    """
    import numpy as np
    import torch

    for key, val in item.items():
        # Replacing values (without adding/removing keys) is safe while iterating.
        if key != "task" and isinstance(val, (np.ndarray, list)):
            item[key] = torch.tensor(val)
    return item

View File

@@ -21,12 +21,15 @@ import warnings
from collections import OrderedDict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, ClassVar
from threading import Lock
from typing import Any, ClassVar, Dict, Literal, Optional
import fsspec
import pyarrow as pa
import torch
import torchvision
from datasets.features.features import register_feature
from line_profiler import profile
from PIL import Image
@@ -74,7 +77,7 @@ def decode_video_frames_torchvision(
video_path: Path | str,
timestamps: list[float],
tolerance_s: float,
backend: str = "pyav",
backend: Literal["pyav", "video_reader"] = "pyav",
log_loaded_timestamps: bool = False,
) -> torch.Tensor:
"""Loads frames associated to the requested timestamps of a video
@@ -169,15 +172,62 @@ def decode_video_frames_torchvision(
return closest_frames
class VideoDecoderCache:
    """Thread-safe cache for video decoders to avoid expensive re-initialization.

    Decoders are keyed by video path (local path or URL) and guarded by a single
    lock, so one instance can be shared across threads. Cached decoders keep their
    underlying file handles open until the cache entries are dropped.
    """

    def __init__(self):
        self._cache: Dict[str, Any] = {}
        self._lock = Lock()

    def get_decoder(self, video_path: str):
        """Return the cached decoder for `video_path`, creating and caching one on a miss.

        Raises:
            ImportError: If a decoder must be created but torchcodec is not installed.
        """
        video_path = str(video_path)
        with self._lock:
            decoder = self._cache.get(video_path)
            if decoder is None:
                # Only require torchcodec when we actually have to build a decoder
                # (the original checked on every call, including cache hits).
                if not importlib.util.find_spec("torchcodec"):
                    raise ImportError("torchcodec is required but not available.")
                from torchcodec.decoders import VideoDecoder

                # fsspec handles both local paths and remote URLs. NOTE(review): the handle
                # is opened via __enter__ and deliberately kept open for the decoder's
                # lifetime; it is released when the cache entry is garbage-collected.
                file_handle = fsspec.open(video_path, client_kwargs={"trust_env": True}).__enter__()
                decoder = VideoDecoder(file_handle, seek_mode="approximate")
                self._cache[video_path] = decoder
            return decoder

    def clear(self):
        """Drop all cached decoders."""
        with self._lock:
            self._cache.clear()

    def size(self) -> int:
        """Return the number of cached decoders."""
        with self._lock:
            return len(self._cache)
# Module-level fallback cache, used by callers that do not supply their own decoder cache.
_default_decoder_cache = VideoDecoderCache()
@profile
def decode_video_frames_torchcodec(
video_path: Path | str,
timestamps: list[float],
tolerance_s: float,
device: str = "cpu",
log_loaded_timestamps: bool = False,
decoder_cache: Optional[VideoDecoderCache] = None,
) -> torch.Tensor:
"""Loads frames associated with the requested timestamps of a video using torchcodec.
Args:
video_path: Path to the video file.
timestamps: List of timestamps to extract frames.
tolerance_s: Allowed deviation in seconds for frame retrieval.
log_loaded_timestamps: Whether to log loaded timestamps.
decoder_cache: Optional decoder cache instance. Uses default if None.
Note: Setting device="cuda" outside the main process, e.g. in data loader workers, will lead to CUDA initialization errors.
Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
@@ -186,23 +236,20 @@ def decode_video_frames_torchcodec(
and all subsequent frames until reaching the requested frame. The number of key frames in a video
can be adjusted during encoding to take into account decoding time and video size in bytes.
"""
if decoder_cache is None:
decoder_cache = _default_decoder_cache
if importlib.util.find_spec("torchcodec"):
from torchcodec.decoders import VideoDecoder
else:
raise ImportError("torchcodec is required but not available.")
# Use cached decoder instead of creating new one each time
decoder = decoder_cache.get_decoder(str(video_path))
# initialize video decoder
decoder = VideoDecoder(video_path, device=device, seek_mode="approximate")
loaded_frames = []
loaded_ts = []
loaded_frames = []
# get metadata for frame information
metadata = decoder.metadata
average_fps = metadata.average_fps
# convert timestamps to frame indices
frame_indices = [round(ts * average_fps) for ts in timestamps]
# retrieve frames based on indices
frames_batch = decoder.get_frames_at(indices=frame_indices)