forked from tangger/lerobot
add: randomized streaming dataset
This commit is contained in:
@@ -151,6 +151,10 @@ class LeRobotDatasetMetadata:
|
|||||||
fpath = self.video_path.format(video_key=vid_key, chunk_index=chunk_idx, file_index=file_idx)
|
fpath = self.video_path.format(video_key=vid_key, chunk_index=chunk_idx, file_index=file_idx)
|
||||||
return Path(fpath)
|
return Path(fpath)
|
||||||
|
|
||||||
|
@property
def url_root(self) -> str:
    """Base URL for resolving this dataset's files on the Hugging Face Hub."""
    return f"https://huggingface.co/datasets/{self.repo_id}/resolve/main"
||||||
@property
|
@property
|
||||||
def data_path(self) -> str:
|
def data_path(self) -> str:
|
||||||
"""Formattable string for the parquet files."""
|
"""Formattable string for the parquet files."""
|
||||||
|
|||||||
252
lerobot/common/datasets/streaming_dataset.py
Normal file
252
lerobot/common/datasets/streaming_dataset.py
Normal file
@@ -0,0 +1,252 @@
|
|||||||
|
import random
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable, Dict, Generator, Iterator
|
||||||
|
|
||||||
|
import datasets
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from datasets import load_dataset
|
||||||
|
from line_profiler import profile
|
||||||
|
|
||||||
|
from lerobot.common.constants import HF_LEROBOT_HOME
|
||||||
|
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
|
||||||
|
from lerobot.common.datasets.utils import (
|
||||||
|
check_version_compatibility,
|
||||||
|
item_to_torch,
|
||||||
|
)
|
||||||
|
from lerobot.common.datasets.video_utils import (
|
||||||
|
VideoDecoderCache,
|
||||||
|
decode_video_frames_torchcodec,
|
||||||
|
get_safe_default_codec,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
    """LeRobotDataset with streaming capabilities.

    Data is streamed (by default from the Hugging Face Hub) rather than loaded entirely
    into memory. This is especially useful for large datasets that may not fit in memory,
    or to quickly explore a dataset without downloading it completely.

    Frames are served from a bounded shuffle buffer of size ``buffer_size`` that is
    refilled by alternating randomly over the dataset's shards, which approximates a
    global shuffle while keeping memory bounded.

    Example:
        Basic usage:
        ```python
        from lerobot.common.datasets.streaming_dataset import StreamingLeRobotDataset

        dataset = StreamingLeRobotDataset(
            repo_id="your-dataset-repo-id",
            buffer_size=1000,
        )

        # Iterate over the dataset
        for i, item in enumerate(dataset):
            print(f"Sample {i}: Episode {item['episode_index']} Frame {item['frame_index']}")
            if i >= 10:
                break
        ```
    """

    def __init__(
        self,
        repo_id: str,
        root: str | Path | None = None,
        episodes: list[int] | None = None,
        image_transforms: Callable | None = None,
        tolerance_s: float = 1e-4,
        revision: str | None = None,
        force_cache_sync: bool = False,
        video_backend: str | None = "torchcodec",
        streaming: bool = True,
        buffer_size: int = 1000,
        max_num_shards: int = 16,
        seed: int = 42,
        rng: np.random.Generator | None = None,
    ):
        """Initialize a StreamingLeRobotDataset.

        Args:
            repo_id (str): Repo id used to fetch the dataset.
            root (str | Path | None, optional): Local directory to use for downloading/writing files.
            episodes (list[int] | None, optional): If specified, only load episodes with these
                episode_index values. NOTE(review): currently stored but not applied when streaming.
            image_transforms (Callable | None, optional): Transform to apply to image data.
                NOTE(review): currently stored but not applied in `make_frame`.
            tolerance_s (float, optional): Tolerance in seconds for timestamp matching.
            revision (str | None, optional): Git revision id (branch name, tag, or commit hash).
            force_cache_sync (bool, optional): Flag to sync and refresh local files first.
            video_backend (str | None, optional): Video backend to use for decoding videos.
                Uses "torchcodec" by default.
            streaming (bool, optional): Whether to stream the dataset or load it all. Defaults to True.
            buffer_size (int, optional): Buffer size for shuffling when streaming. Defaults to 1000.
            max_num_shards (int, optional): Maximum number of shards to iterate over. Defaults to 16.
            seed (int, optional): Reproducibility random seed.
            rng (np.random.Generator | None, optional): Random number generator; overrides `seed`.
        """
        super().__init__()
        self.repo_id = repo_id
        self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
        self.image_transforms = image_transforms
        self.episodes = episodes
        self.tolerance_s = tolerance_s
        self.revision = revision if revision else CODEBASE_VERSION
        self.video_backend = video_backend if video_backend else get_safe_default_codec()
        self.seed = seed
        self.rng = rng if rng is not None else np.random.default_rng(seed)

        self.streaming = streaming
        self.buffer_size = buffer_size

        # We cache the video decoders to avoid re-initializing them at each frame
        # (avoiding a ~10x slowdown).
        self.video_decoder_cache = VideoDecoderCache()

        # Unused attributes, kept for interface parity with LeRobotDataset.
        self.image_writer = None
        self.episode_buffer = None

        self.root.mkdir(exist_ok=True, parents=True)

        # Load metadata and check codebase/dataset version compatibility.
        self.meta = LeRobotDatasetMetadata(
            self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
        )
        check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)

        self.hf_dataset = self.load_hf_dataset()
        self.num_shards = min(self.hf_dataset.num_shards, max_num_shards)

    @property
    def fps(self):
        """Frames per second, as recorded in the dataset metadata."""
        return self.meta.fps

    @staticmethod
    def _iter_random_indices(
        rng: np.random.Generator, buffer_size: int, random_batch_size=1000
    ) -> Iterator[int]:
        """Yield an endless stream of random buffer indices in [0, buffer_size).

        Indices are drawn in batches of `random_batch_size` to amortize the cost of
        calling the generator.
        """
        while True:
            yield from (int(i) for i in rng.integers(0, buffer_size, size=random_batch_size))

    @staticmethod
    def _infinite_generator_over_elements(elements: list[int]) -> Iterator[int]:
        """Yield elements of `elements` at random, forever.

        NOTE(review): uses the module-level `random` state, not the dataset's seeded rng.
        Kept for backward compatibility; `__iter__` now uses `self.rng` for shard choice.
        """
        return (random.choice(list(elements)) for _ in iter(int, 1))

    def load_hf_dataset(self) -> datasets.IterableDataset:
        """Load the underlying Hugging Face dataset (streamed from the Hub by default)."""
        dataset = load_dataset(self.repo_id, split="train", streaming=self.streaming)
        self.streaming_from_local = False

        # TODO(fracapuano): Add support for streaming from a local folder and not only from HF Hub
        return dataset

    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        """Iterate over frames in approximately-shuffled order.

        A bounded buffer of `buffer_size` frames is filled by alternating randomly over
        the dataset's shards; once full, each new frame evicts (and yields) a random
        buffered frame. When all shards are exhausted the remaining buffer is shuffled
        and flushed.
        """
        buffer_indices_generator = self._iter_random_indices(self.rng, self.buffer_size)

        # This buffer is populated while iterating on the dataset's shards
        frames_buffer = []
        idx_to_iterable_dataset = {
            idx: self._make_iterable_dataset(self.hf_dataset.shard(self.num_shards, index=idx))
            for idx in range(self.num_shards)
        }

        while available_shards := list(idx_to_iterable_dataset.keys()):
            # Use the seeded generator (not `random`) so iteration order is reproducible.
            shard_key = available_shards[int(self.rng.integers(len(available_shards)))]
            dataset = idx_to_iterable_dataset[shard_key]  # selects which shard to iterate on
            try:
                for frame in self.make_frame(dataset):
                    if len(frames_buffer) == self.buffer_size:
                        i = next(buffer_indices_generator)
                        yield frames_buffer[i]
                        frames_buffer[i] = frame
                    else:
                        frames_buffer.append(frame)
                    break  # one frame consumed; sample a new random shard
            except (
                RuntimeError,
                StopIteration,
            ):  # NOTE: StopIteration inside a generator throws a RuntimeError since 3.7 (PEP 479)
                # Only this shard is exhausted: drop it and keep draining the others.
                # (Previously the except wrapped the whole while loop, so the first
                # exhausted shard aborted iteration over all remaining shards.)
                del idx_to_iterable_dataset[shard_key]

        # Once shards are all exhausted, shuffle the buffer and yield the remaining frames
        self.rng.shuffle(frames_buffer)
        yield from frames_buffer

    def _make_iterable_dataset(self, dataset: datasets.Dataset) -> Iterator:
        """Return a plain iterator over one shard of the dataset."""
        return iter(dataset)

    def make_frame(self, dataset_iterator: Iterator) -> Generator:
        """Yield the next frame from `dataset_iterator`, decoding video frames when needed."""
        item = next(dataset_iterator)
        item = item_to_torch(item)

        # Get episode index from the item
        ep_idx = item["episode_index"]

        # Load video frames, when needed
        if len(self.meta.video_keys) > 0:
            current_ts = item["timestamp"]
            query_timestamps = self._get_query_timestamps(current_ts, None)
            video_frames = self._query_videos(query_timestamps, ep_idx)
            item = {**video_frames, **item}

        # Add task as a string; plain int index (pandas .iloc rejects non-integer keys)
        task_idx = int(item["task_index"])
        item["task"] = self.meta.tasks.iloc[task_idx].name

        yield item

    def _get_query_timestamps(
        self,
        current_ts: float,
        query_indices: dict[str, list[int]] | None = None,
    ) -> dict[str, list[float]]:
        """Map each video key to the timestamps to decode.

        Defaults to the current frame's timestamp unless explicit `query_indices`
        are provided for a key.
        """
        query_timestamps = {}
        for key in self.meta.video_keys:
            if query_indices is not None and key in query_indices:
                timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
                query_timestamps[key] = torch.stack(timestamps).tolist()
            else:
                query_timestamps[key] = [current_ts]

        return query_timestamps

    def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict:
        """Decode the requested timestamps for each video key.

        Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
        in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a
        Segmentation Fault. This probably happens because a memory reference to the video loader is created in
        the main process and a subprocess fails to access it.
        """
        item = {}
        for vid_key, query_ts in query_timestamps.items():
            # Stream over HTTP from the Hub unless the files are available locally.
            root = self.meta.url_root if self.streaming and not self.streaming_from_local else self.root
            video_path = f"{root}/{self.meta.get_video_file_path(ep_idx, vid_key)}"
            frames = decode_video_frames_torchcodec(
                video_path, query_ts, self.tolerance_s, decoder_cache=self.video_decoder_cache
            )
            item[vid_key] = frames.squeeze(0)

        return item
|
|
||||||
|
# Example usage
if __name__ == "__main__":
    repo_id = "lerobot/aloha_mobile_cabinet"
    dataset = StreamingLeRobotDataset(repo_id)

    for i, frame in enumerate(dataset):
        print(frame)

        # Only stream the first 10 frames (i runs 0..9); the original `i > 10`
        # off-by-two streamed 12.
        if i >= 9:
            break
||||||
@@ -24,7 +24,7 @@ from collections.abc import Iterator
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from typing import Any
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -85,6 +85,8 @@ DEFAULT_FEATURES = {
|
|||||||
"task_index": {"dtype": "int64", "shape": (1,), "names": None},
|
"task_index": {"dtype": "int64", "shape": (1,), "names": None},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
def get_parquet_file_size_in_mb(parquet_path):
|
def get_parquet_file_size_in_mb(parquet_path):
|
||||||
metadata = pq.read_metadata(parquet_path)
|
metadata = pq.read_metadata(parquet_path)
|
||||||
@@ -898,3 +900,24 @@ def safe_write_dataframe_to_parquet(df: pandas.DataFrame, path: Path, image_keys
|
|||||||
datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
|
datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
|
||||||
else:
|
else:
|
||||||
df.to_parquet(path)
|
df.to_parquet(path)
|
||||||
|
|
||||||
|
|
||||||
|
def item_to_torch(item: dict) -> dict:
    """Convert tensor-like entries of a dataset item to ``torch.Tensor``, in place.

    Used to prepare an item coming out of a streaming dataset for PyTorch consumption.
    Numpy arrays and lists are converted; the "task" entry (a string label) and any
    scalar/str values are left untouched.

    Args:
        item (dict): Dictionary of items from a dataset.

    Returns:
        dict: The same dictionary, with tensor-like values converted to torch.Tensor.
    """
    import numpy as np
    import torch

    convertible = (np.ndarray, list)
    for key in item:
        value = item[key]
        if key != "task" and isinstance(value, convertible):
            # numpy arrays and lists become torch tensors; everything else passes through
            item[key] = torch.tensor(value)
    return item
|
|||||||
@@ -21,12 +21,15 @@ import warnings
|
|||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, ClassVar
|
from threading import Lock
|
||||||
|
from typing import Any, ClassVar, Dict, Literal, Optional
|
||||||
|
|
||||||
|
import fsspec
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
import torch
|
import torch
|
||||||
import torchvision
|
import torchvision
|
||||||
from datasets.features.features import register_feature
|
from datasets.features.features import register_feature
|
||||||
|
from line_profiler import profile
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
@@ -74,7 +77,7 @@ def decode_video_frames_torchvision(
|
|||||||
video_path: Path | str,
|
video_path: Path | str,
|
||||||
timestamps: list[float],
|
timestamps: list[float],
|
||||||
tolerance_s: float,
|
tolerance_s: float,
|
||||||
backend: str = "pyav",
|
backend: Literal["pyav", "video_reader"] = "pyav",
|
||||||
log_loaded_timestamps: bool = False,
|
log_loaded_timestamps: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Loads frames associated to the requested timestamps of a video
|
"""Loads frames associated to the requested timestamps of a video
|
||||||
@@ -169,15 +172,62 @@ def decode_video_frames_torchvision(
|
|||||||
return closest_frames
|
return closest_frames
|
||||||
|
|
||||||
|
|
||||||
|
class VideoDecoderCache:
    """Thread-safe cache for video decoders to avoid expensive re-initialization."""

    def __init__(self):
        # path -> decoder; file handles are tracked separately so clear() can close them
        # (the original implementation leaked the fsspec handles on clear()).
        self._cache: Dict[str, Any] = {}
        self._file_handles: Dict[str, Any] = {}
        self._lock = Lock()

    def get_decoder(self, video_path: str):
        """Get a cached decoder for `video_path`, creating and caching one if needed.

        Raises:
            ImportError: if torchcodec is not installed.
        """
        if importlib.util.find_spec("torchcodec"):
            from torchcodec.decoders import VideoDecoder
        else:
            raise ImportError("torchcodec is required but not available.")

        video_path = str(video_path)

        with self._lock:
            if video_path not in self._cache:
                # fsspec transparently handles local paths and remote (e.g. HTTPS) URLs.
                file_handle = fsspec.open(video_path, client_kwargs={"trust_env": True}).__enter__()
                self._file_handles[video_path] = file_handle
                self._cache[video_path] = VideoDecoder(file_handle, seek_mode="approximate")

            return self._cache[video_path]

    def clear(self):
        """Clear the cache, closing the underlying file handles."""
        with self._lock:
            self._cache.clear()
            for handle in self._file_handles.values():
                try:
                    handle.close()
                except Exception:
                    pass  # best-effort cleanup; a failing close must not break callers
            self._file_handles.clear()

    def size(self) -> int:
        """Return the number of cached decoders."""
        with self._lock:
            return len(self._cache)
|
||||||
|
|
||||||
|
# Global instance
|
||||||
|
_default_decoder_cache = VideoDecoderCache()
|
||||||
|
|
||||||
|
|
||||||
|
@profile
|
||||||
def decode_video_frames_torchcodec(
|
def decode_video_frames_torchcodec(
|
||||||
video_path: Path | str,
|
video_path: Path | str,
|
||||||
timestamps: list[float],
|
timestamps: list[float],
|
||||||
tolerance_s: float,
|
tolerance_s: float,
|
||||||
device: str = "cpu",
|
|
||||||
log_loaded_timestamps: bool = False,
|
log_loaded_timestamps: bool = False,
|
||||||
|
decoder_cache: Optional[VideoDecoderCache] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Loads frames associated with the requested timestamps of a video using torchcodec.
|
"""Loads frames associated with the requested timestamps of a video using torchcodec.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
video_path: Path to the video file.
|
||||||
|
timestamps: List of timestamps to extract frames.
|
||||||
|
tolerance_s: Allowed deviation in seconds for frame retrieval.
|
||||||
|
log_loaded_timestamps: Whether to log loaded timestamps.
|
||||||
|
decoder_cache: Optional decoder cache instance. Uses default if None.
|
||||||
|
|
||||||
Note: Setting device="cuda" outside the main process, e.g. in data loader workers, will lead to CUDA initialization errors.
|
Note: Setting device="cuda" outside the main process, e.g. in data loader workers, will lead to CUDA initialization errors.
|
||||||
|
|
||||||
Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
|
Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
|
||||||
@@ -186,23 +236,20 @@ def decode_video_frames_torchcodec(
|
|||||||
and all subsequent frames until reaching the requested frame. The number of key frames in a video
|
and all subsequent frames until reaching the requested frame. The number of key frames in a video
|
||||||
can be adjusted during encoding to take into account decoding time and video size in bytes.
|
can be adjusted during encoding to take into account decoding time and video size in bytes.
|
||||||
"""
|
"""
|
||||||
|
if decoder_cache is None:
|
||||||
|
decoder_cache = _default_decoder_cache
|
||||||
|
|
||||||
if importlib.util.find_spec("torchcodec"):
|
# Use cached decoder instead of creating new one each time
|
||||||
from torchcodec.decoders import VideoDecoder
|
decoder = decoder_cache.get_decoder(str(video_path))
|
||||||
else:
|
|
||||||
raise ImportError("torchcodec is required but not available.")
|
|
||||||
|
|
||||||
# initialize video decoder
|
|
||||||
decoder = VideoDecoder(video_path, device=device, seek_mode="approximate")
|
|
||||||
loaded_frames = []
|
|
||||||
loaded_ts = []
|
loaded_ts = []
|
||||||
|
loaded_frames = []
|
||||||
|
|
||||||
# get metadata for frame information
|
# get metadata for frame information
|
||||||
metadata = decoder.metadata
|
metadata = decoder.metadata
|
||||||
average_fps = metadata.average_fps
|
average_fps = metadata.average_fps
|
||||||
|
|
||||||
# convert timestamps to frame indices
|
# convert timestamps to frame indices
|
||||||
frame_indices = [round(ts * average_fps) for ts in timestamps]
|
frame_indices = [round(ts * average_fps) for ts in timestamps]
|
||||||
|
|
||||||
# retrieve frames based on indices
|
# retrieve frames based on indices
|
||||||
frames_batch = decoder.get_frames_at(indices=frame_indices)
|
frames_batch = decoder.get_frames_at(indices=frame_indices)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user