Add Streaming Dataset (#1613)

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
This commit is contained in:
Francesco Capuano
2025-09-15 14:08:01 +02:00
committed by GitHub
parent f55c6e89f0
commit 33cad37054
10 changed files with 1380 additions and 26 deletions

View File

@@ -0,0 +1,116 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This script demonstrates how to train a Diffusion Policy on the PushT environment,
using a dataset processed in streaming mode.
Once you have trained a model with this script, you can try to evaluate it on
examples/2_evaluate_pretrained_policy.py
"""
from pathlib import Path
import torch
from lerobot.configs.types import FeatureType
from lerobot.constants import ACTION
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
from lerobot.datasets.utils import dataset_to_policy_features
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.act.modeling_act import ACTPolicy
def main():
# Create a directory to store the training checkpoint.
output_directory = Path("outputs/train/example_streaming_dataset")
output_directory.mkdir(parents=True, exist_ok=True)
# Selects the "best" device available
device = (
torch.device("cuda")
if torch.cuda.is_available()
else torch.device("mps")
if torch.backends.mps.is_available()
else torch.device("cpu")
)
print(f"Using device: {device}")
training_steps = 10
log_freq = 1
dataset_id = (
"aractingi/droid_1.0.1" # 26M frames! Would require 4TB of disk space if installed locally (:
)
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
features = dataset_to_policy_features(dataset_metadata.features)
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
input_features = {key: ft for key, ft in features.items() if key not in output_features}
# We can now instantiate our policy with this config and the dataset stats.
cfg = ACTConfig(input_features=input_features, output_features=output_features)
policy = ACTPolicy(cfg, dataset_stats=dataset_metadata.stats)
policy.train()
policy.to(device)
# Delta timestamps are used to (1) augment frames used during training and (2) supervise the policy.
# Here, we use delta-timestamps to only provide ground truth actions for supervision
delta_timestamps = {
ACTION: [t / dataset_metadata.fps for t in range(cfg.n_action_steps)],
}
# Instantiating the training dataset in streaming mode allows to not consume up memory as the data is fetched
# iteratively rather than being load into memory all at once. Retrieved frames are shuffled across epochs
dataset = StreamingLeRobotDataset(dataset_id, delta_timestamps=delta_timestamps, tolerance_s=1e-3)
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=4,
batch_size=16,
pin_memory=device.type != "cpu",
drop_last=True,
prefetch_factor=2, # loads batches with multiprocessing while policy trains
)
# Run training loop.
step = 0
done = False
while not done:
for batch in dataloader:
batch = {
k: (v.type(torch.float32) if isinstance(v, torch.Tensor) and v.dtype != torch.bool else v)
for k, v in batch.items()
}
batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
# batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
loss, _ = policy.forward(batch)
loss.backward()
optimizer.step()
optimizer.zero_grad()
if step % log_freq == 0:
print(f"step: {step} loss: {loss.item():.3f}")
step += 1
if step >= training_steps:
done = True
break
# Save a policy checkpoint.
policy.save_pretrained(output_directory)
if __name__ == "__main__":
main()

View File

@@ -37,6 +37,7 @@ class DatasetConfig:
revision: str | None = None
use_imagenet_stats: bool = True
video_backend: str = field(default_factory=get_safe_default_codec)
streaming: bool = False
@dataclass

View File

@@ -52,3 +52,8 @@ HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expandu
# calibration dir
default_calibration_path = HF_LEROBOT_HOME / "calibration"
HF_LEROBOT_CALIBRATION = Path(os.getenv("HF_LEROBOT_CALIBRATION", default_calibration_path)).expanduser()
# streaming datasets
LOOKBACK_BACKTRACKTABLE = 100
LOOKAHEAD_BACKTRACKTABLE = 100

View File

@@ -25,6 +25,7 @@ from lerobot.datasets.lerobot_dataset import (
LeRobotDatasetMetadata,
MultiLeRobotDataset,
)
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
from lerobot.datasets.transforms import ImageTransforms
IMAGENET_STATS = {
@@ -87,15 +88,26 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
)
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
dataset = LeRobotDataset(
cfg.dataset.repo_id,
root=cfg.dataset.root,
episodes=cfg.dataset.episodes,
delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
revision=cfg.dataset.revision,
video_backend=cfg.dataset.video_backend,
)
if not cfg.dataset.streaming:
dataset = LeRobotDataset(
cfg.dataset.repo_id,
root=cfg.dataset.root,
episodes=cfg.dataset.episodes,
delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
revision=cfg.dataset.revision,
video_backend=cfg.dataset.video_backend,
)
else:
dataset = StreamingLeRobotDataset(
cfg.dataset.repo_id,
root=cfg.dataset.root,
episodes=cfg.dataset.episodes,
delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
revision=cfg.dataset.revision,
max_num_shards=cfg.num_workers,
)
else:
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
dataset = MultiLeRobotDataset(

View File

@@ -129,6 +129,10 @@ class LeRobotDatasetMetadata:
ignore_patterns=ignore_patterns,
)
@property
def url_root(self) -> str:
return f"hf://datasets/{self.repo_id}"
@property
def _version(self) -> packaging.version.Version:
"""Codebase version used to create this dataset."""

View File

@@ -0,0 +1,535 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Callable, Generator, Iterator
from pathlib import Path
import datasets
import numpy as np
import torch
from datasets import load_dataset
from lerobot.constants import HF_LEROBOT_HOME, LOOKAHEAD_BACKTRACKTABLE, LOOKBACK_BACKTRACKTABLE
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
from lerobot.datasets.utils import (
Backtrackable,
LookAheadError,
LookBackError,
check_version_compatibility,
find_float_index,
get_delta_indices,
is_float_in_list,
item_to_torch,
safe_shard,
)
from lerobot.datasets.video_utils import (
VideoDecoderCache,
decode_video_frames_torchcodec,
)
class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
"""LeRobotDataset with streaming capabilities.
This class extends LeRobotDataset to add streaming functionality, allowing data to be streamed
rather than loaded entirely into memory. This is especially useful for large datasets that may
not fit in memory or when you want to quickly explore a dataset without downloading it completely.
The key innovation is using a Backtrackable iterator that maintains a bounded buffer of recent
items, allowing us to access previous frames for delta timestamps without loading the entire
dataset into memory.
Example:
Basic usage:
```python
from lerobot.common.datasets.streaming_dataset import StreamingLeRobotDataset
# Create a streaming dataset with delta timestamps
delta_timestamps = {
"observation.image": [-1.0, -0.5, 0.0], # 1 sec ago, 0.5 sec ago, current
"action": [0.0, 0.1, 0.2], # current, 0.1 sec future, 0.2 sec future
}
dataset = StreamingLeRobotDataset(
repo_id="your-dataset-repo-id",
delta_timestamps=delta_timestamps,
streaming=True,
buffer_size=1000,
)
# Iterate over the dataset
for i, item in enumerate(dataset):
print(f"Sample {i}: Episode {item['episode_index']} Frame {item['frame_index']}")
# item will contain stacked frames according to delta_timestamps
if i >= 10:
break
```
"""
def __init__(
self,
repo_id: str,
root: str | Path | None = None,
episodes: list[int] | None = None,
image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None,
tolerance_s: float = 1e-4,
revision: str | None = None,
force_cache_sync: bool = False,
streaming: bool = True,
buffer_size: int = 1000,
max_num_shards: int = 16,
seed: int = 42,
rng: np.random.Generator | None = None,
shuffle: bool = True,
):
"""Initialize a StreamingLeRobotDataset.
Args:
repo_id (str): This is the repo id that will be used to fetch the dataset.
root (Path | None, optional): Local directory to use for downloading/writing files.
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
their episode_index in this list.
image_transforms (Callable | None, optional): Transform to apply to image data.
tolerance_s (float, optional): Tolerance in seconds for timestamp matching.
revision (str, optional): Git revision id (branch name, tag, or commit hash).
force_cache_sync (bool, optional): Flag to sync and refresh local files first.
streaming (bool, optional): Whether to stream the dataset or load it all. Defaults to True.
buffer_size (int, optional): Buffer size for shuffling when streaming. Defaults to 1000.
max_num_shards (int, optional): Number of shards to re-shard the input dataset into. Defaults to 16.
seed (int, optional): Reproducibility random seed.
rng (np.random.Generator | None, optional): Random number generator.
shuffle (bool, optional): Whether to shuffle the dataset across exhaustions. Defaults to True.
"""
super().__init__()
self.repo_id = repo_id
self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
self.streaming_from_local = root is not None
self.image_transforms = image_transforms
self.episodes = episodes
self.tolerance_s = tolerance_s
self.revision = revision if revision else CODEBASE_VERSION
self.seed = seed
self.rng = rng if rng is not None else np.random.default_rng(seed)
self.shuffle = shuffle
self.streaming = streaming
self.buffer_size = buffer_size
# We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown)
self.video_decoder_cache = None
self.root.mkdir(exist_ok=True, parents=True)
# Load metadata
self.meta = LeRobotDatasetMetadata(
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
)
# Check version
check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
self.delta_timestamps = None
self.delta_indices = None
if delta_timestamps is not None:
self._validate_delta_timestamp_keys(delta_timestamps) # raises ValueError if invalid
self.delta_timestamps = delta_timestamps
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
self.hf_dataset: datasets.IterableDataset = load_dataset(
self.repo_id if not self.streaming_from_local else str(self.root),
split="train",
streaming=self.streaming,
data_files="data/*/*.parquet",
revision=self.revision,
)
self.num_shards = min(self.hf_dataset.num_shards, max_num_shards)
@property
def num_frames(self):
return self.meta.total_frames
@property
def num_episodes(self):
return self.meta.total_episodes
@property
def fps(self):
return self.meta.fps
@staticmethod
def _iter_random_indices(
rng: np.random.Generator, buffer_size: int, random_batch_size=100
) -> Iterator[int]:
while True:
yield from (int(i) for i in rng.integers(0, buffer_size, size=random_batch_size))
@staticmethod
def _infinite_generator_over_elements(rng: np.random.Generator, elements: list[int]) -> Iterator[int]:
while True:
yield rng.choice(elements)
# TODO(fracapuano): Implement multi-threaded prefetching to accelerate data loading.
# The current sequential iteration is a bottleneck. A producer-consumer pattern
# could be used with a ThreadPoolExecutor to run `make_frame` (especially video decoding)
# in parallel, feeding a queue from which this iterator will yield processed items.
def __iter__(self) -> Iterator[dict[str, torch.Tensor]]:
if self.video_decoder_cache is None:
self.video_decoder_cache = VideoDecoderCache()
# keep the same seed across exhaustions if shuffle is False, otherwise shuffle data across exhaustions
rng = np.random.default_rng(self.seed) if not self.shuffle else self.rng
buffer_indices_generator = self._iter_random_indices(rng, self.buffer_size)
idx_to_backtrack_dataset = {
idx: self._make_backtrackable_dataset(safe_shard(self.hf_dataset, idx, self.num_shards))
for idx in range(self.num_shards)
}
# This buffer is populated while iterating on the dataset's shards
# the logic is to add 2 levels of randomness:
# (1) sample one shard at random from the ones available, and
# (2) sample one frame from the shard sampled at (1)
frames_buffer = []
while available_shards := list(idx_to_backtrack_dataset.keys()):
shard_key = next(self._infinite_generator_over_elements(rng, available_shards))
backtrack_dataset = idx_to_backtrack_dataset[shard_key] # selects which shard to iterate on
try:
for frame in self.make_frame(backtrack_dataset):
if len(frames_buffer) == self.buffer_size:
i = next(buffer_indices_generator) # samples a element from the buffer
yield frames_buffer[i]
frames_buffer[i] = frame
else:
frames_buffer.append(frame)
break # random shard sampled, switch shard
except (
RuntimeError,
StopIteration,
): # NOTE: StopIteration inside a generator throws a RuntimeError since python 3.7
del idx_to_backtrack_dataset[shard_key] # Remove exhausted shard, onto another shard
# Once shards are all exhausted, shuffle the buffer and yield the remaining frames
rng.shuffle(frames_buffer)
yield from frames_buffer
def _get_window_steps(
self, delta_timestamps: dict[str, list[float]] | None = None, dynamic_bounds: bool = False
) -> tuple[int, int]:
if delta_timestamps is None:
return 1, 1
if not dynamic_bounds:
# Fix the windows
lookback = LOOKBACK_BACKTRACKTABLE
lookahead = LOOKAHEAD_BACKTRACKTABLE
else:
# Dynamically adjust the windows based on the given delta_timesteps
all_timestamps = sum(delta_timestamps.values(), [])
lookback = min(all_timestamps) * self.fps
lookahead = max(all_timestamps) * self.fps
# When lookback is >=0 it means no negative timesteps have been provided
lookback = 0 if lookback >= 0 else (lookback * -1)
return lookback, lookahead
def _make_backtrackable_dataset(self, dataset: datasets.IterableDataset) -> Backtrackable:
lookback, lookahead = self._get_window_steps(self.delta_timestamps)
return Backtrackable(dataset, history=lookback, lookahead=lookahead)
def _make_timestamps_from_indices(
self, start_ts: float, indices: dict[str, list[int]] | None = None
) -> dict[str, list[float]]:
if indices is not None:
return {
key: (
start_ts + torch.tensor(indices[key]) / self.fps
).tolist() # NOTE: why not delta_timestamps directly?
for key in self.delta_timestamps
}
else:
return dict.fromkeys(self.meta.video_keys, [start_ts])
def _make_padding_camera_frame(self, camera_key: str):
"""Variable-shape padding frame for given camera keys, given in (H, W, C)"""
return torch.zeros(self.meta.info["features"][camera_key]["shape"]).permute(-1, 0, 1)
def _get_video_frame_padding_mask(
self,
video_frames: dict[str, torch.Tensor],
query_timestamps: dict[str, list[float]],
original_timestamps: dict[str, list[float]],
) -> dict[str, torch.BoolTensor]:
padding_mask = {}
for video_key, timestamps in original_timestamps.items():
if video_key not in video_frames:
continue # only padding on video keys that are available
frames = []
mask = []
padding_frame = self._make_padding_camera_frame(video_key)
for ts in timestamps:
if is_float_in_list(ts, query_timestamps[video_key]):
idx = find_float_index(ts, query_timestamps[video_key])
frames.append(video_frames[video_key][idx, :])
mask.append(False)
else:
frames.append(padding_frame)
mask.append(True)
padding_mask[f"{video_key}_is_pad"] = torch.BoolTensor(mask)
return padding_mask
def make_frame(
self, dataset_iterator: Backtrackable, previous_dataset_iterator: Backtrackable | None = None
) -> Generator:
"""Makes a frame starting from a dataset iterator"""
item = next(dataset_iterator)
item = item_to_torch(item)
updates = [] # list of "updates" to apply to the item retrieved from hf_dataset (w/o camera features)
# Get episode index from the item
ep_idx = item["episode_index"]
# "timestamp" restarts from 0 for each episode, whereas we need a global timestep within the single .mp4 file (given by index/fps)
current_ts = item["index"] / self.fps
episode_boundaries_ts = {
key: (
self.meta.episodes[ep_idx][f"videos/{key}/from_timestamp"],
self.meta.episodes[ep_idx][f"videos/{key}/to_timestamp"],
)
for key in self.meta.video_keys
}
# Apply delta querying logic if necessary
if self.delta_indices is not None:
query_result, padding = self._get_delta_frames(dataset_iterator, item)
updates.append(query_result)
updates.append(padding)
# Load video frames, when needed
if len(self.meta.video_keys) > 0:
original_timestamps = self._make_timestamps_from_indices(current_ts, self.delta_indices)
# Some timestamps might not result available considering the episode's boundaries
query_timestamps = self._get_query_timestamps(
current_ts, self.delta_indices, episode_boundaries_ts
)
video_frames = self._query_videos(query_timestamps, ep_idx)
if self.image_transforms is not None:
image_keys = self.meta.camera_keys
for cam in image_keys:
video_frames[cam] = self.image_transforms(video_frames[cam])
updates.append(video_frames)
if self.delta_indices is not None:
# We always return the same number of frames. Unavailable frames are padded.
padding_mask = self._get_video_frame_padding_mask(
video_frames, query_timestamps, original_timestamps
)
updates.append(padding_mask)
result = item.copy()
for update in updates:
result.update(update)
result["task"] = self.meta.tasks.iloc[item["task_index"]].name
yield result
def _get_query_timestamps(
self,
current_ts: float,
query_indices: dict[str, list[int]] | None = None,
episode_boundaries_ts: dict[str, tuple[float, float]] | None = None,
) -> dict[str, list[float]]:
query_timestamps = {}
keys_to_timestamps = self._make_timestamps_from_indices(current_ts, query_indices)
for key in self.meta.video_keys:
if query_indices is not None and key in query_indices:
timestamps = keys_to_timestamps[key]
# Clamp out timesteps outside of episode boundaries
query_timestamps[key] = torch.clamp(
torch.tensor(timestamps), *episode_boundaries_ts[key]
).tolist()
else:
query_timestamps[key] = [current_ts]
return query_timestamps
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict:
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a
Segmentation Fault. This probably happens because a memory reference to the video loader is created in
the main process and a subprocess fails to access it.
"""
item = {}
for video_key, query_ts in query_timestamps.items():
root = self.meta.url_root if self.streaming and not self.streaming_from_local else self.root
video_path = f"{root}/{self.meta.get_video_file_path(ep_idx, video_key)}"
frames = decode_video_frames_torchcodec(
video_path, query_ts, self.tolerance_s, decoder_cache=self.video_decoder_cache
)
item[video_key] = frames.squeeze(0) if len(query_ts) == 1 else frames
return item
def _get_delta_frames(self, dataset_iterator: Backtrackable, current_item: dict):
# TODO(fracapuano): Modularize this function, refactor the code
"""Get frames with delta offsets using the backtrackable iterator.
Args:
current_item (dict): Current item from the iterator.
ep_idx (int): Episode index.
Returns:
tuple: (query_result, padding) - frames at delta offsets and padding info.
"""
current_episode_idx = current_item["episode_index"]
# Prepare results
query_result = {}
padding = {}
for key, delta_indices in self.delta_indices.items():
if key in self.meta.video_keys:
continue # visual frames are decoded separately
target_frames = []
is_pad = []
# Create a results dictionary to store frames in processing order, then reconstruct original order for stacking
delta_results = {}
# Separate and sort deltas by difficulty (easier operations first)
negative_deltas = sorted([d for d in delta_indices if d < 0], reverse=True) # [-1, -2, -3, ...]
positive_deltas = sorted([d for d in delta_indices if d > 0]) # [1, 2, 3, ...]
zero_deltas = [d for d in delta_indices if d == 0]
# Process zero deltas (current frame)
for delta in zero_deltas:
delta_results[delta] = (
current_item[key],
False,
)
# Process negative deltas in order of increasing difficulty
lookback_failed = False
last_successful_frame = current_item[key]
for delta in negative_deltas:
if lookback_failed:
delta_results[delta] = (last_successful_frame, True)
continue
try:
steps_back = abs(delta)
if dataset_iterator.can_peek_back(steps_back):
past_item = dataset_iterator.peek_back(steps_back)
past_item = item_to_torch(past_item)
if past_item["episode_index"] == current_episode_idx:
delta_results[delta] = (past_item[key], False)
last_successful_frame = past_item[key]
else:
raise LookBackError("Retrieved frame is from different episode!")
else:
raise LookBackError("Cannot go back further than the history buffer!")
except LookBackError:
delta_results[delta] = (last_successful_frame, True)
lookback_failed = True # All subsequent negative deltas will also fail
# Process positive deltas in order of increasing difficulty
lookahead_failed = False
last_successful_frame = current_item[key]
for delta in positive_deltas:
if lookahead_failed:
delta_results[delta] = (last_successful_frame, True)
continue
try:
if dataset_iterator.can_peek_ahead(delta):
future_item = dataset_iterator.peek_ahead(delta)
future_item = item_to_torch(future_item)
if future_item["episode_index"] == current_episode_idx:
delta_results[delta] = (future_item[key], False)
last_successful_frame = future_item[key]
else:
raise LookAheadError("Retrieved frame is from different episode!")
else:
raise LookAheadError("Cannot go ahead further than the lookahead buffer!")
except LookAheadError:
delta_results[delta] = (last_successful_frame, True)
lookahead_failed = True # All subsequent positive deltas will also fail
# Reconstruct original order for stacking
for delta in delta_indices:
frame, is_padded = delta_results[delta]
# add batch dimension for stacking
target_frames.append(frame) # frame.unsqueeze(0))
is_pad.append(is_padded)
# Stack frames and add to results
if target_frames:
query_result[key] = torch.stack(target_frames)
padding[f"{key}_is_pad"] = torch.BoolTensor(is_pad)
return query_result, padding
def _validate_delta_timestamp_keys(self, delta_timestamps: dict[list[float]]) -> None:
"""
Validate that all keys in delta_timestamps correspond to actual features in the dataset.
Raises:
ValueError: If any delta timestamp key doesn't correspond to a dataset feature.
"""
if delta_timestamps is None:
return
# Get all available feature keys from the dataset metadata
available_features = set(self.meta.features.keys())
# Get all keys from delta_timestamps
delta_keys = set(delta_timestamps.keys())
# Find any keys that don't correspond to features
invalid_keys = delta_keys - available_features
if invalid_keys:
raise ValueError(
f"The following delta_timestamp keys do not correspond to dataset features: {invalid_keys}. "
f"Available features are: {sorted(available_features)}"
)

View File

@@ -17,10 +17,11 @@ import contextlib
import importlib.resources
import json
import logging
from collections.abc import Iterator
from collections import deque
from collections.abc import Iterable, Iterator
from pathlib import Path
from pprint import pformat
from typing import Any
from typing import Any, Deque, Generic, TypeVar
import datasets
import numpy as np
@@ -86,6 +87,8 @@ DEFAULT_FEATURES = {
"task_index": {"dtype": "int64", "shape": (1,), "names": None},
}
T = TypeVar("T")
def get_parquet_file_size_in_mb(parquet_path: str | Path) -> float:
metadata = pq.read_metadata(parquet_path)
@@ -776,3 +779,230 @@ def to_parquet_with_hf_images(df: pandas.DataFrame, path: Path) -> None:
"""
# TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
def item_to_torch(item: dict) -> dict:
"""Convert all items in a dictionary to PyTorch tensors where appropriate.
This function is used to convert an item from a streaming dataset to PyTorch tensors.
Args:
item (dict): Dictionary of items from a dataset.
Returns:
dict: Dictionary with all tensor-like items converted to torch.Tensor.
"""
for key, val in item.items():
if isinstance(val, (np.ndarray, list)) and key not in ["task"]:
# Convert numpy arrays and lists to torch tensors
item[key] = torch.tensor(val)
return item
def is_float_in_list(target, float_list, threshold=1e-6):
return any(abs(target - x) <= threshold for x in float_list)
def find_float_index(target, float_list, threshold=1e-6):
for i, x in enumerate(float_list):
if abs(target - x) <= threshold:
return i
return -1
class LookBackError(Exception):
"""
Exception raised when trying to look back in the history of a Backtrackable object.
"""
pass
class LookAheadError(Exception):
"""
Exception raised when trying to look ahead in the future of a Backtrackable object.
"""
pass
class Backtrackable(Generic[T]):
"""
Wrap any iterator/iterable so you can step back up to `history` items
and look ahead up to `lookahead` items.
This is useful for streaming datasets where you need to access previous and future items
but can't load the entire dataset into memory.
Example:
-------
```python
ds = load_dataset("c4", "en", streaming=True, split="train")
rev = Backtrackable(ds, history=3, lookahead=2)
x0 = next(rev) # forward
x1 = next(rev)
x2 = next(rev)
# Look ahead
x3_peek = rev.peek_ahead(1) # next item without moving cursor
x4_peek = rev.peek_ahead(2) # two items ahead
# Look back
x1_again = rev.peek_back(1) # previous item without moving cursor
x0_again = rev.peek_back(2) # two items back
# Move backward
x1_back = rev.prev() # back one step
next(rev) # returns x2, continues forward from where we were
```
"""
__slots__ = ("_source", "_back_buf", "_ahead_buf", "_cursor", "_history", "_lookahead")
def __init__(self, iterable: Iterable[T], *, history: int = 1, lookahead: int = 0):
if history < 1:
raise ValueError("history must be >= 1")
if lookahead <= 0:
raise ValueError("lookahead must be > 0")
self._source: Iterator[T] = iter(iterable)
self._back_buf: Deque[T] = deque(maxlen=history)
self._ahead_buf: Deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque()
self._cursor: int = 0
self._history = history
self._lookahead = lookahead
def __iter__(self) -> "Backtrackable[T]":
return self
def __next__(self) -> T:
# If we've stepped back, consume from back buffer first
if self._cursor < 0: # -1 means "last item", etc.
self._cursor += 1
return self._back_buf[self._cursor]
# If we have items in the ahead buffer, use them first
item = self._ahead_buf.popleft() if self._ahead_buf else next(self._source)
# Add current item to back buffer and reset cursor
self._back_buf.append(item)
self._cursor = 0
return item
def prev(self) -> T:
"""
Step one item back in history and return it.
Raises IndexError if already at the oldest buffered item.
"""
if len(self._back_buf) + self._cursor <= 1:
raise LookBackError("At start of history")
self._cursor -= 1
return self._back_buf[self._cursor]
def peek_back(self, n: int = 1) -> T:
"""
Look `n` items back (n=1 == previous item) without moving the cursor.
"""
if n < 0 or n + 1 > len(self._back_buf) + self._cursor:
raise LookBackError("peek_back distance out of range")
return self._back_buf[self._cursor - (n + 1)]
def peek_ahead(self, n: int = 1) -> T:
"""
Look `n` items ahead (n=1 == next item) without moving the cursor.
Fills the ahead buffer if necessary.
"""
if n < 1:
raise LookAheadError("peek_ahead distance must be 1 or more")
elif n > self._lookahead:
raise LookAheadError("peek_ahead distance exceeds lookahead limit")
# Fill ahead buffer if we don't have enough items
while len(self._ahead_buf) < n:
try:
item = next(self._source)
self._ahead_buf.append(item)
except StopIteration as err:
raise LookAheadError("peek_ahead: not enough items in source") from err
return self._ahead_buf[n - 1]
def history(self) -> list[T]:
"""
Return a copy of the buffered history (most recent last).
The list length ≤ `history` argument passed at construction.
"""
if self._cursor == 0:
return list(self._back_buf)
# When cursor<0, slice so the order remains chronological
return list(self._back_buf)[: self._cursor or None]
def lookahead_buffer(self) -> list[T]:
"""
Return a copy of the current lookahead buffer.
"""
return list(self._ahead_buf)
def can_peek_back(self, steps: int = 1) -> bool:
"""
Check if we can go back `steps` items without raising an IndexError.
"""
return steps <= len(self._back_buf) + self._cursor
def can_peek_ahead(self, steps: int = 1) -> bool:
"""
Check if we can peek ahead `steps` items.
This may involve trying to fill the ahead buffer.
"""
if self._lookahead > 0 and steps > self._lookahead:
return False
# Try to fill ahead buffer to check if we can peek that far
try:
while len(self._ahead_buf) < steps:
if self._lookahead > 0 and len(self._ahead_buf) >= self._lookahead:
return False
item = next(self._source)
self._ahead_buf.append(item)
return True
except StopIteration:
return False
def reset_cursor(self) -> None:
"""
Reset cursor to the most recent position (equivalent to calling next()
until you're back to the latest item).
"""
self._cursor = 0
def clear_ahead_buffer(self) -> None:
"""
Clear the ahead buffer, discarding any pre-fetched items.
"""
self._ahead_buf.clear()
def switch_source_iterable(self, new_source: Iterable[T]) -> None:
"""
Switch the source of the backtrackable to a new iterable, keeping the history.
This is useful when iterating over a sequence of datasets. The history from the
previous source is kept, but the lookahead buffer is cleared. The cursor is reset
to the present.
"""
self._source = iter(new_source)
self.clear_ahead_buffer()
self.reset_cursor()
def safe_shard(dataset: datasets.IterableDataset, index: int, num_shards: int) -> datasets.Dataset:
"""
Safe shards the dataset.
"""
shard_idx = min(dataset.num_shards, index + 1) - 1
return dataset.shard(num_shards, index=shard_idx)

View File

@@ -21,9 +21,11 @@ import tempfile
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from threading import Lock
from typing import Any, ClassVar
import av
import fsspec
import pyarrow as pa
import torch
import torchvision
@@ -169,15 +171,68 @@ def decode_video_frames_torchvision(
return closest_frames
class VideoDecoderCache:
"""Thread-safe cache for video decoders to avoid expensive re-initialization."""
def __init__(self):
self._cache: dict[str, tuple[Any, Any]] = {}
self._lock = Lock()
def get_decoder(self, video_path: str):
"""Get a cached decoder or create a new one."""
if importlib.util.find_spec("torchcodec"):
from torchcodec.decoders import VideoDecoder
else:
raise ImportError("torchcodec is required but not available.")
video_path = str(video_path)
with self._lock:
if video_path not in self._cache:
file_handle = fsspec.open(video_path).__enter__()
decoder = VideoDecoder(file_handle, seek_mode="approximate")
self._cache[video_path] = (decoder, file_handle)
return self._cache[video_path][0]
def clear(self):
"""Clear the cache and close file handles."""
with self._lock:
for _, file_handle in self._cache.values():
file_handle.close()
self._cache.clear()
def size(self) -> int:
"""Return the number of cached decoders."""
with self._lock:
return len(self._cache)
class FrameTimestampError(ValueError):
"""Helper error to indicate the retrieved timestamps exceed the queried ones"""
pass
_default_decoder_cache = VideoDecoderCache()
def decode_video_frames_torchcodec(
video_path: Path | str,
timestamps: list[float],
tolerance_s: float,
device: str = "cpu",
log_loaded_timestamps: bool = False,
decoder_cache: VideoDecoderCache | None = None,
) -> torch.Tensor:
"""Loads frames associated with the requested timestamps of a video using torchcodec.
Args:
video_path: Path to the video file.
timestamps: List of timestamps to extract frames.
tolerance_s: Allowed deviation in seconds for frame retrieval.
log_loaded_timestamps: Whether to log loaded timestamps.
decoder_cache: Optional decoder cache instance. Uses default if None.
Note: Setting device="cuda" outside the main process, e.g. in data loader workers, will lead to CUDA initialization errors.
Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
@@ -186,27 +241,24 @@ def decode_video_frames_torchcodec(
and all subsequent frames until reaching the requested frame. The number of key frames in a video
can be adjusted during encoding to take into account decoding time and video size in bytes.
"""
if decoder_cache is None:
decoder_cache = _default_decoder_cache
if importlib.util.find_spec("torchcodec"):
from torchcodec.decoders import VideoDecoder
else:
raise ImportError("torchcodec is required but not available.")
# Use cached decoder instead of creating new one each time
decoder = decoder_cache.get_decoder(str(video_path))
# initialize video decoder
decoder = VideoDecoder(video_path, device=device, seek_mode="approximate")
loaded_frames = []
loaded_ts = []
loaded_frames = []
# get metadata for frame information
metadata = decoder.metadata
average_fps = metadata.average_fps
# convert timestamps to frame indices
frame_indices = [round(ts * average_fps) for ts in timestamps]
# retrieve frames based on indices
frames_batch = decoder.get_frames_at(indices=frame_indices)
for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=False):
for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=True):
loaded_frames.append(frame)
loaded_ts.append(pts.item())
if log_loaded_timestamps:
@@ -237,10 +289,14 @@ def decode_video_frames_torchcodec(
if log_loaded_timestamps:
logging.info(f"{closest_ts=}")
# convert to float32 in [0,1] range (channel first)
closest_frames = closest_frames.type(torch.float32) / 255
# convert to float32 in [0,1] range
closest_frames = (closest_frames / 255.0).type(torch.float32)
if not len(timestamps) == len(closest_frames):
raise FrameTimestampError(
f"Retrieved timestamps differ from queried {set(closest_frames) - set(timestamps)}"
)
assert len(timestamps) == len(closest_frames)
return closest_frames

View File

@@ -179,10 +179,11 @@ def train(cfg: TrainPipelineConfig):
dataset,
num_workers=cfg.num_workers,
batch_size=cfg.batch_size,
shuffle=shuffle,
shuffle=shuffle and not cfg.dataset.streaming,
sampler=sampler,
pin_memory=device.type == "cuda",
drop_last=False,
prefetch_factor=2,
)
dl_iter = cycle(dataloader)
@@ -208,6 +209,9 @@ def train(cfg: TrainPipelineConfig):
for key in batch:
if isinstance(batch[key], torch.Tensor):
if batch[key].dtype != torch.bool:
batch[key] = batch[key].type(torch.float32) if device.type == "mps" else batch[key]
batch[key] = batch[key].to(device, non_blocking=device.type == "cuda")
train_tracker, output_dict = update_policy(

View File

@@ -0,0 +1,391 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytest
import torch
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
from lerobot.datasets.utils import safe_shard
from tests.fixtures.constants import DUMMY_REPO_ID
def get_frames_expected_order(streaming_ds: StreamingLeRobotDataset) -> list[int]:
"""Replicates the shuffling logic of StreamingLeRobotDataset to get the expected order of indices."""
rng = np.random.default_rng(streaming_ds.seed)
buffer_size = streaming_ds.buffer_size
num_shards = streaming_ds.num_shards
shards_indices = []
for shard_idx in range(num_shards):
shard = streaming_ds.hf_dataset.shard(num_shards, index=shard_idx)
shard_indices = [item["index"] for item in shard]
shards_indices.append(shard_indices)
shard_iterators = {i: iter(s) for i, s in enumerate(shards_indices)}
buffer_indices_generator = streaming_ds._iter_random_indices(rng, buffer_size)
frames_buffer = []
expected_indices = []
while shard_iterators: # While there are still available shards
available_shard_keys = list(shard_iterators.keys())
if not available_shard_keys:
break
# Call _infinite_generator_over_elements with current available shards (key difference!)
shard_key = next(streaming_ds._infinite_generator_over_elements(rng, available_shard_keys))
try:
frame_index = next(shard_iterators[shard_key])
if len(frames_buffer) == buffer_size:
i = next(buffer_indices_generator)
expected_indices.append(frames_buffer[i])
frames_buffer[i] = frame_index
else:
frames_buffer.append(frame_index)
except StopIteration:
del shard_iterators[shard_key] # Remove exhausted shard
rng.shuffle(frames_buffer)
expected_indices.extend(frames_buffer)
return expected_indices
def test_single_frame_consistency(tmp_path, lerobot_dataset_factory):
"""Test if are correctly accessed"""
ds_num_frames = 400
ds_num_episodes = 10
buffer_size = 100
local_path = tmp_path / "test"
repo_id = f"{DUMMY_REPO_ID}"
ds = lerobot_dataset_factory(
root=local_path,
repo_id=repo_id,
total_episodes=ds_num_episodes,
total_frames=ds_num_frames,
)
streaming_ds = iter(StreamingLeRobotDataset(repo_id=repo_id, root=local_path, buffer_size=buffer_size))
key_checks = []
for _ in range(ds_num_frames):
streaming_frame = next(streaming_ds)
frame_idx = streaming_frame["index"]
target_frame = ds[frame_idx]
for key in streaming_frame:
left = streaming_frame[key]
right = target_frame[key]
if isinstance(left, str):
check = left == right
elif isinstance(left, torch.Tensor):
check = torch.allclose(left, right) and left.shape == right.shape
elif isinstance(left, float):
check = left == right.item() # right is a torch.Tensor
key_checks.append((key, check))
assert all(t[1] for t in key_checks), (
f"Checking {list(filter(lambda t: not t[1], key_checks))[0][0]} left and right were found different (frame_idx: {frame_idx})"
)
@pytest.mark.parametrize(
"shuffle",
[False, True],
)
def test_frames_order_over_epochs(tmp_path, lerobot_dataset_factory, shuffle):
"""Test if streamed frames correspond to shuffling operations over in-memory dataset."""
ds_num_frames = 400
ds_num_episodes = 10
buffer_size = 100
seed = 42
n_epochs = 3
local_path = tmp_path / "test"
repo_id = f"{DUMMY_REPO_ID}"
lerobot_dataset_factory(
root=local_path,
repo_id=repo_id,
total_episodes=ds_num_episodes,
total_frames=ds_num_frames,
)
streaming_ds = StreamingLeRobotDataset(
repo_id=repo_id, root=local_path, buffer_size=buffer_size, seed=seed, shuffle=shuffle
)
first_epoch_indices = [frame["index"] for frame in streaming_ds]
expected_indices = get_frames_expected_order(streaming_ds)
assert first_epoch_indices == expected_indices, "First epoch indices do not match expected indices"
expected_indices = get_frames_expected_order(streaming_ds)
for _ in range(n_epochs):
streaming_indices = [frame["index"] for frame in streaming_ds]
frames_match = all(
s_index == e_index for s_index, e_index in zip(streaming_indices, expected_indices, strict=True)
)
if shuffle:
assert not frames_match
else:
assert frames_match
@pytest.mark.parametrize(
"shuffle",
[False, True],
)
def test_frames_order_with_shards(tmp_path, lerobot_dataset_factory, shuffle):
"""Test if streamed frames correspond to shuffling operations over in-memory dataset with multiple shards."""
ds_num_frames = 100
ds_num_episodes = 10
buffer_size = 10
seed = 42
n_epochs = 3
data_file_size_mb = 0.001
chunks_size = 1
local_path = tmp_path / "test"
repo_id = f"{DUMMY_REPO_ID}-ciao"
lerobot_dataset_factory(
root=local_path,
repo_id=repo_id,
total_episodes=ds_num_episodes,
total_frames=ds_num_frames,
data_files_size_in_mb=data_file_size_mb,
chunks_size=chunks_size,
)
streaming_ds = StreamingLeRobotDataset(
repo_id=repo_id,
root=local_path,
buffer_size=buffer_size,
seed=seed,
shuffle=shuffle,
max_num_shards=4,
)
first_epoch_indices = [frame["index"] for frame in streaming_ds]
expected_indices = get_frames_expected_order(streaming_ds)
assert first_epoch_indices == expected_indices, "First epoch indices do not match expected indices"
for _ in range(n_epochs):
streaming_indices = [
frame["index"] for frame in streaming_ds
] # NOTE: this is the same as first_epoch_indices
frames_match = all(
s_index == e_index for s_index, e_index in zip(streaming_indices, expected_indices, strict=True)
)
if shuffle:
assert not frames_match
else:
assert frames_match
@pytest.mark.parametrize(
"state_deltas, action_deltas",
[
([-1, -0.5, -0.20, 0], [0, 1, 2, 3]),
([-1, -0.5, -0.20, 0], [-1.5, -1, -0.5, -0.20, -0.10, 0]),
([-2, -1, -0.5, 0], [0, 1, 2, 3]),
([-2, -1, -0.5, 0], [-1.5, -1, -0.5, -0.20, -0.10, 0]),
],
)
def test_frames_with_delta_consistency(tmp_path, lerobot_dataset_factory, state_deltas, action_deltas):
ds_num_frames = 500
ds_num_episodes = 10
buffer_size = 100
seed = 42
local_path = tmp_path / "test"
repo_id = f"{DUMMY_REPO_ID}-ciao"
camera_key = "phone"
delta_timestamps = {
camera_key: state_deltas,
"state": state_deltas,
"action": action_deltas,
}
ds = lerobot_dataset_factory(
root=local_path,
repo_id=repo_id,
total_episodes=ds_num_episodes,
total_frames=ds_num_frames,
delta_timestamps=delta_timestamps,
)
streaming_ds = iter(
StreamingLeRobotDataset(
repo_id=repo_id,
root=local_path,
buffer_size=buffer_size,
seed=seed,
shuffle=False,
delta_timestamps=delta_timestamps,
)
)
for i in range(ds_num_frames):
streaming_frame = next(streaming_ds)
frame_idx = streaming_frame["index"]
target_frame = ds[frame_idx]
assert set(streaming_frame.keys()) == set(target_frame.keys()), (
f"Keys differ between streaming frame and target one. Differ at: {set(streaming_frame.keys()) - set(target_frame.keys())}"
)
key_checks = []
for key in streaming_frame:
left = streaming_frame[key]
right = target_frame[key]
if isinstance(left, str):
check = left == right
elif isinstance(left, torch.Tensor):
if (
key not in ds.meta.camera_keys
and "is_pad" not in key
and f"{key}_is_pad" in streaming_frame
):
# comparing frames only on non-padded regions. Padding is applied to last-valid broadcasting
left = left[~streaming_frame[f"{key}_is_pad"]]
right = right[~target_frame[f"{key}_is_pad"]]
check = torch.allclose(left, right) and left.shape == right.shape
key_checks.append((key, check))
assert all(t[1] for t in key_checks), (
f"Checking {list(filter(lambda t: not t[1], key_checks))[0][0]} left and right were found different (i: {i}, frame_idx: {frame_idx})"
)
@pytest.mark.parametrize(
"state_deltas, action_deltas",
[
([-1, -0.5, -0.20, 0], [0, 1, 2, 3, 10, 20]),
([-1, -0.5, -0.20, 0], [-20, -1.5, -1, -0.5, -0.20, -0.10, 0]),
([-2, -1, -0.5, 0], [0, 1, 2, 3, 10, 20]),
([-2, -1, -0.5, 0], [-20, -1.5, -1, -0.5, -0.20, -0.10, 0]),
],
)
def test_frames_with_delta_consistency_with_shards(
tmp_path, lerobot_dataset_factory, state_deltas, action_deltas
):
ds_num_frames = 100
ds_num_episodes = 10
buffer_size = 10
data_file_size_mb = 0.001
chunks_size = 1
seed = 42
local_path = tmp_path / "test"
repo_id = f"{DUMMY_REPO_ID}-ciao"
camera_key = "phone"
delta_timestamps = {
camera_key: state_deltas,
"state": state_deltas,
"action": action_deltas,
}
ds = lerobot_dataset_factory(
root=local_path,
repo_id=repo_id,
total_episodes=ds_num_episodes,
total_frames=ds_num_frames,
delta_timestamps=delta_timestamps,
data_files_size_in_mb=data_file_size_mb,
chunks_size=chunks_size,
)
streaming_ds = StreamingLeRobotDataset(
repo_id=repo_id,
root=local_path,
buffer_size=buffer_size,
seed=seed,
shuffle=False,
delta_timestamps=delta_timestamps,
max_num_shards=4,
)
iter(streaming_ds)
num_shards = 4
shards_indices = []
for shard_idx in range(num_shards):
shard = safe_shard(streaming_ds.hf_dataset, shard_idx, num_shards)
shard_indices = [item["index"] for item in shard]
shards_indices.append(shard_indices)
streaming_ds = iter(streaming_ds)
for i in range(ds_num_frames):
streaming_frame = next(streaming_ds)
frame_idx = streaming_frame["index"]
target_frame = ds[frame_idx]
assert set(streaming_frame.keys()) == set(target_frame.keys()), (
f"Keys differ between streaming frame and target one. Differ at: {set(streaming_frame.keys()) - set(target_frame.keys())}"
)
key_checks = []
for key in streaming_frame:
left = streaming_frame[key]
right = target_frame[key]
if isinstance(left, str):
check = left == right
elif isinstance(left, torch.Tensor):
if (
key not in ds.meta.camera_keys
and "is_pad" not in key
and f"{key}_is_pad" in streaming_frame
):
# comparing frames only on non-padded regions. Padding is applied to last-valid broadcasting
left = left[~streaming_frame[f"{key}_is_pad"]]
right = right[~target_frame[f"{key}_is_pad"]]
check = torch.allclose(left, right) and left.shape == right.shape
elif isinstance(left, float):
check = left == right.item() # right is a torch.Tensor
key_checks.append((key, check))
assert all(t[1] for t in key_checks), (
f"Checking {list(filter(lambda t: not t[1], key_checks))[0][0]} left and right were found different (i: {i}, frame_idx: {frame_idx})"
)