diff --git a/examples/5_train_with_streaming.py b/examples/5_train_with_streaming.py
new file mode 100644
index 00000000..17818410
--- /dev/null
+++ b/examples/5_train_with_streaming.py
@@ -0,0 +1,116 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""This script demonstrates how to train an ACT policy on the DROID dataset,
+processed in streaming mode.
+
+Once you have trained a model with this script, you can try to evaluate it with
+examples/2_evaluate_pretrained_policy.py
+"""
+
+from pathlib import Path
+
+import torch
+
+from lerobot.configs.types import FeatureType
+from lerobot.constants import ACTION
+from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
+from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
+from lerobot.datasets.utils import dataset_to_policy_features
+from lerobot.policies.act.configuration_act import ACTConfig
+from lerobot.policies.act.modeling_act import ACTPolicy
+
+
+def main():
+    # Create a directory to store the training checkpoint.
+    output_directory = Path("outputs/train/example_streaming_dataset")
+    output_directory.mkdir(parents=True, exist_ok=True)
+
+    # Select the "best" device available
+    device = (
+        torch.device("cuda")
+        if torch.cuda.is_available()
+        else torch.device("mps")
+        if torch.backends.mps.is_available()
+        else torch.device("cpu")
+    )
+    print(f"Using device: {device}")
+
+    training_steps = 10
+    log_freq = 1
+
+    dataset_id = (
+        "aractingi/droid_1.0.1"  # 26M frames! Would require ~4TB of disk space if downloaded locally (:
+    )
+    dataset_metadata = LeRobotDatasetMetadata(dataset_id)
+    features = dataset_to_policy_features(dataset_metadata.features)
+    output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
+    input_features = {key: ft for key, ft in features.items() if key not in output_features}
+
+    # We can now instantiate our policy with this config and the dataset stats.
+    cfg = ACTConfig(input_features=input_features, output_features=output_features)
+    policy = ACTPolicy(cfg, dataset_stats=dataset_metadata.stats)
+    policy.train()
+    policy.to(device)
+
+    # Delta timestamps are used to (1) augment frames used during training and (2) supervise the policy.
+    # Here, we use delta timestamps only to provide the ground-truth action chunk for supervision.
+    delta_timestamps = {
+        ACTION: [t / dataset_metadata.fps for t in range(cfg.n_action_steps)],
+    }
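+
+    # For intuition (values hypothetical, since fps and n_action_steps depend on the dataset and
+    # policy config): with fps=10 and n_action_steps=3, the mapping above expands to
+    #     {ACTION: [0.0, 0.1, 0.2]}
+    # i.e. the current action plus the next two, spaced one frame apart.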
+    # Instantiating the training dataset in streaming mode avoids loading the data into memory all at
+    # once: frames are fetched iteratively instead. Retrieved frames are shuffled across epochs.
+    dataset = StreamingLeRobotDataset(dataset_id, delta_timestamps=delta_timestamps, tolerance_s=1e-3)
+
+    optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        num_workers=4,
+        batch_size=16,
+        pin_memory=device.type != "cpu",
+        drop_last=True,
+        prefetch_factor=2,  # loads batches with multiprocessing while the policy trains
+    )
+
+    # Run training loop.
+    step = 0
+    done = False
+    while not done:
+        for batch in dataloader:
+            # Cast non-bool tensors to float32, then move everything to the target device.
+            batch = {
+                k: (v.type(torch.float32) if isinstance(v, torch.Tensor) and v.dtype != torch.bool else v)
+                for k, v in batch.items()
+            }
+            batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
+
+            loss, _ = policy.forward(batch)
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+
+            if step % log_freq == 0:
+                print(f"step: {step} loss: {loss.item():.3f}")
+            step += 1
+            if step >= training_steps:
+                done = True
+                break
+
+    # Save a policy checkpoint.
+    policy.save_pretrained(output_directory)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/lerobot/configs/default.py b/src/lerobot/configs/default.py
index 53cfe58e..1bc2b8d1 100644
--- a/src/lerobot/configs/default.py
+++ b/src/lerobot/configs/default.py
@@ -37,6 +37,7 @@ class DatasetConfig:
     revision: str | None = None
     use_imagenet_stats: bool = True
     video_backend: str = field(default_factory=get_safe_default_codec)
+    streaming: bool = False
 
 
 @dataclass
diff --git a/src/lerobot/constants.py b/src/lerobot/constants.py
index 30777239..382435a9 100644
--- a/src/lerobot/constants.py
+++ b/src/lerobot/constants.py
@@ -52,3 +52,8 @@ HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expandu
 # calibration dir
 default_calibration_path = HF_LEROBOT_HOME / "calibration"
 HF_LEROBOT_CALIBRATION = Path(os.getenv("HF_LEROBOT_CALIBRATION", default_calibration_path)).expanduser()
+
+
+# streaming datasets
+LOOKBACK_BACKTRACKTABLE = 100
+LOOKAHEAD_BACKTRACKTABLE = 100
diff --git a/src/lerobot/datasets/factory.py b/src/lerobot/datasets/factory.py
index e06650bc..a71e978b 100644
--- a/src/lerobot/datasets/factory.py
+++ b/src/lerobot/datasets/factory.py
@@ -25,6 +25,7 @@ from lerobot.datasets.lerobot_dataset import (
     LeRobotDatasetMetadata,
     MultiLeRobotDataset,
 )
+from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
 from lerobot.datasets.transforms import ImageTransforms
 
 IMAGENET_STATS = {
@@ -87,15 +88,26 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
             cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
         )
         delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
-        dataset = LeRobotDataset(
-            cfg.dataset.repo_id,
-            root=cfg.dataset.root,
-            episodes=cfg.dataset.episodes,
-            delta_timestamps=delta_timestamps,
-            image_transforms=image_transforms,
-            revision=cfg.dataset.revision,
-            video_backend=cfg.dataset.video_backend,
-        )
+        if not cfg.dataset.streaming:
+            dataset = LeRobotDataset(
+                cfg.dataset.repo_id,
+                root=cfg.dataset.root,
+                episodes=cfg.dataset.episodes,
+                delta_timestamps=delta_timestamps,
+                image_transforms=image_transforms,
+                revision=cfg.dataset.revision,
+                video_backend=cfg.dataset.video_backend,
+            )
+        else:
+            dataset = StreamingLeRobotDataset(
+                cfg.dataset.repo_id,
+                root=cfg.dataset.root,
+                episodes=cfg.dataset.episodes,
+                delta_timestamps=delta_timestamps,
+                image_transforms=image_transforms,
+                revision=cfg.dataset.revision,
+                max_num_shards=cfg.num_workers,
+            )
     else:
         raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
         dataset = MultiLeRobotDataset(
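With the new `streaming` flag, switching an existing training run to streaming is a one-field change. A minimal sketch (assuming `DatasetConfig` keeps its existing `repo_id` field and the usual dataclass-driven CLI parsing):

```python
from lerobot.configs.default import DatasetConfig

# Same dataset as before, but make_dataset() will now build a StreamingLeRobotDataset.
dataset_cfg = DatasetConfig(repo_id="lerobot/pusht", streaming=True)
```

On the command line this would presumably be `--dataset.streaming=true`, following the flag style of the other `DatasetConfig` fields.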
diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py
index ceefcf05..9cd4b6bf 100644
--- a/src/lerobot/datasets/lerobot_dataset.py
+++ b/src/lerobot/datasets/lerobot_dataset.py
@@ -129,6 +129,10 @@ class LeRobotDatasetMetadata:
             ignore_patterns=ignore_patterns,
         )
 
+    @property
+    def url_root(self) -> str:
+        return f"hf://datasets/{self.repo_id}"
+
     @property
     def _version(self) -> packaging.version.Version:
         """Codebase version used to create this dataset."""
diff --git a/src/lerobot/datasets/streaming_dataset.py b/src/lerobot/datasets/streaming_dataset.py
new file mode 100644
index 00000000..e354c406
--- /dev/null
+++ b/src/lerobot/datasets/streaming_dataset.py
@@ -0,0 +1,535 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from collections.abc import Callable, Generator, Iterator
+from pathlib import Path
+
+import datasets
+import numpy as np
+import torch
+from datasets import load_dataset
+
+from lerobot.constants import HF_LEROBOT_HOME, LOOKAHEAD_BACKTRACKTABLE, LOOKBACK_BACKTRACKTABLE
+from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
+from lerobot.datasets.utils import (
+    Backtrackable,
+    LookAheadError,
+    LookBackError,
+    check_version_compatibility,
+    find_float_index,
+    get_delta_indices,
+    is_float_in_list,
+    item_to_torch,
+    safe_shard,
+)
+from lerobot.datasets.video_utils import (
+    VideoDecoderCache,
+    decode_video_frames_torchcodec,
+)
+
+
+class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
+    """LeRobotDataset with streaming capabilities.
+
+    This class provides the LeRobotDataset interface with streaming functionality, allowing data to be
+    streamed rather than loaded entirely into memory. This is especially useful for large datasets that
+    may not fit in memory, or when you want to quickly explore a dataset without downloading it completely.
+
+    The key idea is a Backtrackable iterator that maintains a bounded buffer of recent items, giving
+    access to nearby past and future frames for delta timestamps without loading the entire dataset
+    into memory.
+
+    Example:
+        Basic usage:
+        ```python
+        from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
+
+        # Create a streaming dataset with delta timestamps
+        delta_timestamps = {
+            "observation.image": [-1.0, -0.5, 0.0],  # 1 sec ago, 0.5 sec ago, current
+            "action": [0.0, 0.1, 0.2],  # current, 0.1 sec in the future, 0.2 sec in the future
+        }
+
+        dataset = StreamingLeRobotDataset(
+            repo_id="your-dataset-repo-id",
+            delta_timestamps=delta_timestamps,
+            streaming=True,
+            buffer_size=1000,
+        )
+
+        # Iterate over the dataset
+        for i, item in enumerate(dataset):
+            print(f"Sample {i}: Episode {item['episode_index']} Frame {item['frame_index']}")
+            # item will contain stacked frames according to delta_timestamps
+            if i >= 10:
+                break
+        ```
+    """
+
+    def __init__(
+        self,
+        repo_id: str,
+        root: str | Path | None = None,
+        episodes: list[int] | None = None,
+        image_transforms: Callable | None = None,
+        delta_timestamps: dict[str, list[float]] | None = None,
+        tolerance_s: float = 1e-4,
+        revision: str | None = None,
+        force_cache_sync: bool = False,
+        streaming: bool = True,
+        buffer_size: int = 1000,
+        max_num_shards: int = 16,
+        seed: int = 42,
+        rng: np.random.Generator | None = None,
+        shuffle: bool = True,
+    ):
+        """Initialize a StreamingLeRobotDataset.
+
+        Args:
+            repo_id (str): Repo id used to fetch the dataset.
+            root (Path | None, optional): Local directory to use for downloading/writing files.
+            episodes (list[int] | None, optional): If specified, only load the episodes with these
+                episode_index values.
+            image_transforms (Callable | None, optional): Transform to apply to image data.
+            delta_timestamps (dict[str, list[float]] | None, optional): Per-feature time offsets, in
+                seconds, of the additional frames to return with each sample.
+            tolerance_s (float, optional): Tolerance in seconds for timestamp matching.
+            revision (str, optional): Git revision id (branch name, tag, or commit hash).
+            force_cache_sync (bool, optional): Flag to sync and refresh local files first.
+            streaming (bool, optional): Whether to stream the dataset or load it all. Defaults to True.
+            buffer_size (int, optional): Buffer size for shuffling when streaming. Defaults to 1000.
+            max_num_shards (int, optional): Number of shards to re-shard the input dataset into. Defaults to 16.
+            seed (int, optional): Random seed for reproducibility.
+            rng (np.random.Generator | None, optional): Random number generator.
+            shuffle (bool, optional): Whether to shuffle the dataset across exhaustions. Defaults to True.
+ """ + super().__init__() + self.repo_id = repo_id + self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id + self.streaming_from_local = root is not None + + self.image_transforms = image_transforms + self.episodes = episodes + self.tolerance_s = tolerance_s + self.revision = revision if revision else CODEBASE_VERSION + self.seed = seed + self.rng = rng if rng is not None else np.random.default_rng(seed) + self.shuffle = shuffle + + self.streaming = streaming + self.buffer_size = buffer_size + + # We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown) + self.video_decoder_cache = None + + self.root.mkdir(exist_ok=True, parents=True) + + # Load metadata + self.meta = LeRobotDatasetMetadata( + self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync + ) + # Check version + check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION) + + self.delta_timestamps = None + self.delta_indices = None + + if delta_timestamps is not None: + self._validate_delta_timestamp_keys(delta_timestamps) # raises ValueError if invalid + self.delta_timestamps = delta_timestamps + self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps) + + self.hf_dataset: datasets.IterableDataset = load_dataset( + self.repo_id if not self.streaming_from_local else str(self.root), + split="train", + streaming=self.streaming, + data_files="data/*/*.parquet", + revision=self.revision, + ) + + self.num_shards = min(self.hf_dataset.num_shards, max_num_shards) + + @property + def num_frames(self): + return self.meta.total_frames + + @property + def num_episodes(self): + return self.meta.total_episodes + + @property + def fps(self): + return self.meta.fps + + @staticmethod + def _iter_random_indices( + rng: np.random.Generator, buffer_size: int, random_batch_size=100 + ) -> Iterator[int]: + while True: + yield from (int(i) for i in rng.integers(0, buffer_size, size=random_batch_size)) + + @staticmethod + def _infinite_generator_over_elements(rng: np.random.Generator, elements: list[int]) -> Iterator[int]: + while True: + yield rng.choice(elements) + + # TODO(fracapuano): Implement multi-threaded prefetching to accelerate data loading. + # The current sequential iteration is a bottleneck. A producer-consumer pattern + # could be used with a ThreadPoolExecutor to run `make_frame` (especially video decoding) + # in parallel, feeding a queue from which this iterator will yield processed items. 
+    def __iter__(self) -> Iterator[dict[str, torch.Tensor]]:
+        if self.video_decoder_cache is None:
+            self.video_decoder_cache = VideoDecoderCache()
+
+        # Keep the same seed across exhaustions if shuffle is False; otherwise shuffle data across exhaustions
+        rng = np.random.default_rng(self.seed) if not self.shuffle else self.rng
+
+        buffer_indices_generator = self._iter_random_indices(rng, self.buffer_size)
+
+        idx_to_backtrack_dataset = {
+            idx: self._make_backtrackable_dataset(safe_shard(self.hf_dataset, idx, self.num_shards))
+            for idx in range(self.num_shards)
+        }
+
+        # This buffer is populated while iterating on the dataset's shards.
+        # The logic adds two levels of randomness:
+        # (1) sample one shard at random from the ones available, and
+        # (2) sample one frame at random from the buffer when replacing it with a frame from (1)
+        frames_buffer = []
+        while available_shards := list(idx_to_backtrack_dataset.keys()):
+            shard_key = next(self._infinite_generator_over_elements(rng, available_shards))
+            backtrack_dataset = idx_to_backtrack_dataset[shard_key]  # selects which shard to iterate on
+
+            try:
+                for frame in self.make_frame(backtrack_dataset):
+                    if len(frames_buffer) == self.buffer_size:
+                        i = next(buffer_indices_generator)  # samples an element from the buffer
+                        yield frames_buffer[i]
+                        frames_buffer[i] = frame
+                    else:
+                        frames_buffer.append(frame)
+                break  # one frame taken from this shard; sample a new shard
+            except (
+                RuntimeError,
+                StopIteration,
+            ):  # NOTE: StopIteration raised inside a generator becomes a RuntimeError since Python 3.7 (PEP 479)
+                del idx_to_backtrack_dataset[shard_key]  # Remove exhausted shard, move on to another shard
+
+        # Once all shards are exhausted, shuffle the buffer and yield the remaining frames
+        rng.shuffle(frames_buffer)
+        yield from frames_buffer
+
+    def _get_window_steps(
+        self, delta_timestamps: dict[str, list[float]] | None = None, dynamic_bounds: bool = False
+    ) -> tuple[int, int]:
+        if delta_timestamps is None:
+            return 1, 1
+
+        if not dynamic_bounds:
+            # Fixed windows
+            lookback = LOOKBACK_BACKTRACKTABLE
+            lookahead = LOOKAHEAD_BACKTRACKTABLE
+        else:
+            # Dynamically adjust the windows based on the given delta_timestamps
+            all_timestamps = sum(delta_timestamps.values(), [])
+            lookback = round(min(all_timestamps) * self.fps)
+            lookahead = round(max(all_timestamps) * self.fps)
+
+        # When lookback is >= 0, no negative timestamps have been provided
+        lookback = 0 if lookback >= 0 else (lookback * -1)
+
+        return lookback, lookahead
+
+    def _make_backtrackable_dataset(self, dataset: datasets.IterableDataset) -> Backtrackable:
+        lookback, lookahead = self._get_window_steps(self.delta_timestamps)
+        return Backtrackable(dataset, history=lookback, lookahead=lookahead)
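+    # For intuition on the window arithmetic above (values hypothetical): with fps=10 and
+    # delta_timestamps={"action": [-0.5, 0.0, 0.3]}, dynamic bounds give
+    #     lookback  = round(-0.5 * 10) = -5  ->  5 steps of history
+    #     lookahead = round( 0.3 * 10) =  3  ->  3 steps of lookahead
+    # whereas the default fixed windows use LOOKBACK_BACKTRACKTABLE / LOOKAHEAD_BACKTRACKTABLE (100 each).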
+    def _make_timestamps_from_indices(
+        self, start_ts: float, indices: dict[str, list[int]] | None = None
+    ) -> dict[str, list[float]]:
+        if indices is not None:
+            return {
+                key: (
+                    start_ts + torch.tensor(indices[key]) / self.fps
+                ).tolist()  # NOTE: why not delta_timestamps directly?
+                for key in self.delta_timestamps
+            }
+        else:
+            return dict.fromkeys(self.meta.video_keys, [start_ts])
+
+    def _make_padding_camera_frame(self, camera_key: str):
+        """Zero-valued padding frame for the given camera key; the feature shape is (H, W, C), permuted to (C, H, W)."""
+        return torch.zeros(self.meta.info["features"][camera_key]["shape"]).permute(-1, 0, 1)
+
+    def _get_video_frame_padding_mask(
+        self,
+        video_frames: dict[str, torch.Tensor],
+        query_timestamps: dict[str, list[float]],
+        original_timestamps: dict[str, list[float]],
+    ) -> dict[str, torch.BoolTensor]:
+        padding_mask = {}
+
+        for video_key, timestamps in original_timestamps.items():
+            if video_key not in video_frames:
+                continue  # only pad video keys that are present
+            frames = []
+            mask = []
+            padding_frame = self._make_padding_camera_frame(video_key)
+            for ts in timestamps:
+                if is_float_in_list(ts, query_timestamps[video_key]):
+                    idx = find_float_index(ts, query_timestamps[video_key])
+                    frames.append(video_frames[video_key][idx, :])
+                    mask.append(False)
+                else:
+                    frames.append(padding_frame)
+                    mask.append(True)
+
+            padding_mask[f"{video_key}_is_pad"] = torch.BoolTensor(mask)
+
+        return padding_mask
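+    # The mask above hinges on approximate float matching: a requested timestamp counts as available
+    # only if a decoded timestamp lies within the helpers' 1e-6 threshold, e.g. (values hypothetical):
+    #     query = [0.0333333, 0.0666667, 0.1]
+    #     is_float_in_list(0.033333, query)   # True  -> frame kept, mask False
+    #     find_float_index(0.033333, query)   # 0
+    #     is_float_in_list(0.05, query)       # False -> zero padding frame, mask True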
+    def make_frame(self, dataset_iterator: Backtrackable) -> Generator:
+        """Builds a full sample from the next item of a (backtrackable) dataset iterator."""
+        item = next(dataset_iterator)
+        item = item_to_torch(item)
+
+        updates = []  # list of "updates" to apply to the item retrieved from hf_dataset (w/o camera features)
+
+        # Get episode index from the item
+        ep_idx = item["episode_index"]
+
+        # "timestamp" restarts from 0 for each episode, whereas we need a global timestamp within the
+        # single .mp4 file (given by index/fps)
+        current_ts = item["index"] / self.fps
+
+        episode_boundaries_ts = {
+            key: (
+                self.meta.episodes[ep_idx][f"videos/{key}/from_timestamp"],
+                self.meta.episodes[ep_idx][f"videos/{key}/to_timestamp"],
+            )
+            for key in self.meta.video_keys
+        }
+
+        # Apply delta querying logic if necessary
+        if self.delta_indices is not None:
+            query_result, padding = self._get_delta_frames(dataset_iterator, item)
+            updates.append(query_result)
+            updates.append(padding)
+
+        # Load video frames, when needed
+        if len(self.meta.video_keys) > 0:
+            original_timestamps = self._make_timestamps_from_indices(current_ts, self.delta_indices)
+
+            # Some timestamps may not be available within the episode's boundaries
+            query_timestamps = self._get_query_timestamps(
+                current_ts, self.delta_indices, episode_boundaries_ts
+            )
+            video_frames = self._query_videos(query_timestamps, ep_idx)
+
+            if self.image_transforms is not None:
+                image_keys = self.meta.camera_keys
+                for cam in image_keys:
+                    video_frames[cam] = self.image_transforms(video_frames[cam])
+
+            updates.append(video_frames)
+
+            if self.delta_indices is not None:
+                # We always return the same number of frames. Unavailable frames are padded.
+                padding_mask = self._get_video_frame_padding_mask(
+                    video_frames, query_timestamps, original_timestamps
+                )
+                updates.append(padding_mask)
+
+        result = item.copy()
+        for update in updates:
+            result.update(update)
+
+        result["task"] = self.meta.tasks.iloc[item["task_index"]].name
+
+        yield result
+
+    def _get_query_timestamps(
+        self,
+        current_ts: float,
+        query_indices: dict[str, list[int]] | None = None,
+        episode_boundaries_ts: dict[str, tuple[float, float]] | None = None,
+    ) -> dict[str, list[float]]:
+        query_timestamps = {}
+        keys_to_timestamps = self._make_timestamps_from_indices(current_ts, query_indices)
+        for key in self.meta.video_keys:
+            if query_indices is not None and key in query_indices:
+                timestamps = keys_to_timestamps[key]
+                # Clamp timestamps to the episode boundaries
+                query_timestamps[key] = torch.clamp(
+                    torch.tensor(timestamps), *episode_boundaries_ts[key]
+                ).tolist()
+            else:
+                query_timestamps[key] = [current_ts]
+
+        return query_timestamps
+
+    def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict:
+        """Note: When using data workers (e.g. a DataLoader with num_workers > 0), do not call this
+        function in the main process (e.g. via a second DataLoader with num_workers=0). It will result
+        in a segmentation fault. This probably happens because a memory reference to the video loader
+        is created in the main process and a subprocess fails to access it.
+        """
+        item = {}
+        for video_key, query_ts in query_timestamps.items():
+            root = self.meta.url_root if self.streaming and not self.streaming_from_local else self.root
+            video_path = f"{root}/{self.meta.get_video_file_path(ep_idx, video_key)}"
+            frames = decode_video_frames_torchcodec(
+                video_path, query_ts, self.tolerance_s, decoder_cache=self.video_decoder_cache
+            )
+
+            item[video_key] = frames.squeeze(0) if len(query_ts) == 1 else frames
+
+        return item
+
+    def _get_delta_frames(self, dataset_iterator: Backtrackable, current_item: dict):
+        # TODO(fracapuano): Modularize this function, refactor the code
+        """Get frames at the configured delta offsets using the backtrackable iterator.
+
+        Args:
+            dataset_iterator (Backtrackable): Iterator positioned at the current item.
+            current_item (dict): Current item from the iterator.
+
+        Returns:
+            tuple: (query_result, padding) - frames at delta offsets and padding info.
+        """
+        current_episode_idx = current_item["episode_index"]
+
+        # Prepare results
+        query_result = {}
+        padding = {}
+
+        for key, delta_indices in self.delta_indices.items():
+            if key in self.meta.video_keys:
+                continue  # visual frames are decoded separately
+
+            target_frames = []
+            is_pad = []
+
+            # Store frames in processing order, then reconstruct the original order for stacking
+            delta_results = {}
+
+            # Separate and sort deltas by difficulty (easier operations first)
+            negative_deltas = sorted([d for d in delta_indices if d < 0], reverse=True)  # [-1, -2, -3, ...]
+            positive_deltas = sorted([d for d in delta_indices if d > 0])  # [1, 2, 3, ...]
+            zero_deltas = [d for d in delta_indices if d == 0]
+
+            # Process zero deltas (current frame)
+            for delta in zero_deltas:
+                delta_results[delta] = (
+                    current_item[key],
+                    False,
+                )
+
+            # Process negative deltas in order of increasing difficulty
+            lookback_failed = False
+            last_successful_frame = current_item[key]
+
+            for delta in negative_deltas:
+                if lookback_failed:
+                    delta_results[delta] = (last_successful_frame, True)
+                    continue
+
+                try:
+                    steps_back = abs(delta)
+                    if dataset_iterator.can_peek_back(steps_back):
+                        past_item = dataset_iterator.peek_back(steps_back)
+                        past_item = item_to_torch(past_item)
+
+                        if past_item["episode_index"] == current_episode_idx:
+                            delta_results[delta] = (past_item[key], False)
+                            last_successful_frame = past_item[key]
+                        else:
+                            raise LookBackError("Retrieved frame is from a different episode!")
+                    else:
+                        raise LookBackError("Cannot go back further than the history buffer!")
+
+                except LookBackError:
+                    delta_results[delta] = (last_successful_frame, True)
+                    lookback_failed = True  # All subsequent negative deltas will also fail
+
+            # Process positive deltas in order of increasing difficulty
+            lookahead_failed = False
+            last_successful_frame = current_item[key]
+
+            for delta in positive_deltas:
+                if lookahead_failed:
+                    delta_results[delta] = (last_successful_frame, True)
+                    continue
+
+                try:
+                    if dataset_iterator.can_peek_ahead(delta):
+                        future_item = dataset_iterator.peek_ahead(delta)
+                        future_item = item_to_torch(future_item)
+
+                        if future_item["episode_index"] == current_episode_idx:
+                            delta_results[delta] = (future_item[key], False)
+                            last_successful_frame = future_item[key]
+                        else:
+                            raise LookAheadError("Retrieved frame is from a different episode!")
+                    else:
+                        raise LookAheadError("Cannot go ahead further than the lookahead buffer!")
+
+                except LookAheadError:
+                    delta_results[delta] = (last_successful_frame, True)
+                    lookahead_failed = True  # All subsequent positive deltas will also fail
+
+            # Reconstruct the original order for stacking
+            for delta in delta_indices:
+                frame, is_padded = delta_results[delta]
+
+                target_frames.append(frame)  # torch.stack below adds the stacking dimension
+                is_pad.append(is_padded)
+
+            # Stack frames and add to results
+            if target_frames:
+                query_result[key] = torch.stack(target_frames)
+                padding[f"{key}_is_pad"] = torch.BoolTensor(is_pad)
+
+        return query_result, padding
+
+    def _validate_delta_timestamp_keys(self, delta_timestamps: dict[str, list[float]]) -> None:
+        """
+        Validate that all keys in delta_timestamps correspond to actual features in the dataset.
+
+        Raises:
+            ValueError: If any delta timestamp key doesn't correspond to a dataset feature.
+        """
+        if delta_timestamps is None:
+            return
+
+        # Get all available feature keys from the dataset metadata
+        available_features = set(self.meta.features.keys())
+
+        # Get all keys from delta_timestamps
+        delta_keys = set(delta_timestamps.keys())
+
+        # Find any keys that don't correspond to features
+        invalid_keys = delta_keys - available_features
+
+        if invalid_keys:
+            raise ValueError(
+                f"The following delta_timestamp keys do not correspond to dataset features: {invalid_keys}. "
+                f"Available features are: {sorted(available_features)}"
+            )
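For reference, `_get_delta_frames` operates on frame offsets rather than seconds: `get_delta_indices` (imported from `lerobot.datasets.utils`) converts the user-facing `delta_timestamps` into integer steps. Assuming its existing seconds-to-frames behavior, the conversion looks like:

```python
from lerobot.datasets.utils import get_delta_indices

delta_timestamps = {"action": [-0.2, 0.0, 0.1, 0.2]}
# At fps=10 these become frame offsets relative to the current index:
print(get_delta_indices(delta_timestamps, fps=10))  # expected: {"action": [-2, 0, 1, 2]}
```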
" + f"Available features are: {sorted(available_features)}" + ) diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index 2b0d95e1..c840d5bc 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -17,10 +17,11 @@ import contextlib import importlib.resources import json import logging -from collections.abc import Iterator +from collections import deque +from collections.abc import Iterable, Iterator from pathlib import Path from pprint import pformat -from typing import Any +from typing import Any, Deque, Generic, TypeVar import datasets import numpy as np @@ -86,6 +87,8 @@ DEFAULT_FEATURES = { "task_index": {"dtype": "int64", "shape": (1,), "names": None}, } +T = TypeVar("T") + def get_parquet_file_size_in_mb(parquet_path: str | Path) -> float: metadata = pq.read_metadata(parquet_path) @@ -776,3 +779,230 @@ def to_parquet_with_hf_images(df: pandas.DataFrame, path: Path) -> None: """ # TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path) + + +def item_to_torch(item: dict) -> dict: + """Convert all items in a dictionary to PyTorch tensors where appropriate. + + This function is used to convert an item from a streaming dataset to PyTorch tensors. + + Args: + item (dict): Dictionary of items from a dataset. + + Returns: + dict: Dictionary with all tensor-like items converted to torch.Tensor. + """ + for key, val in item.items(): + if isinstance(val, (np.ndarray, list)) and key not in ["task"]: + # Convert numpy arrays and lists to torch tensors + item[key] = torch.tensor(val) + return item + + +def is_float_in_list(target, float_list, threshold=1e-6): + return any(abs(target - x) <= threshold for x in float_list) + + +def find_float_index(target, float_list, threshold=1e-6): + for i, x in enumerate(float_list): + if abs(target - x) <= threshold: + return i + return -1 + + +class LookBackError(Exception): + """ + Exception raised when trying to look back in the history of a Backtrackable object. + """ + + pass + + +class LookAheadError(Exception): + """ + Exception raised when trying to look ahead in the future of a Backtrackable object. + """ + + pass + + +class Backtrackable(Generic[T]): + """ + Wrap any iterator/iterable so you can step back up to `history` items + and look ahead up to `lookahead` items. + + This is useful for streaming datasets where you need to access previous and future items + but can't load the entire dataset into memory. 
+
+    Example:
+    -------
+    ```python
+    ds = load_dataset("c4", "en", streaming=True, split="train")
+    rev = Backtrackable(ds, history=3, lookahead=2)
+
+    x0 = next(rev)  # forward
+    x1 = next(rev)
+    x2 = next(rev)
+
+    # Look ahead
+    x3_peek = rev.peek_ahead(1)  # next item without moving cursor
+    x4_peek = rev.peek_ahead(2)  # two items ahead
+
+    # Look back
+    x1_again = rev.peek_back(1)  # previous item without moving cursor
+    x0_again = rev.peek_back(2)  # two items back
+
+    # Move backward
+    x1_back = rev.prev()  # back one step
+    next(rev)  # returns x2, continues forward from where we were
+    ```
+    """
+
+    __slots__ = ("_source", "_back_buf", "_ahead_buf", "_cursor", "_history", "_lookahead")
+
+    def __init__(self, iterable: Iterable[T], *, history: int = 1, lookahead: int = 0):
+        if history < 1:
+            raise ValueError("history must be >= 1")
+        if lookahead < 0:
+            raise ValueError("lookahead must be >= 0")
+
+        self._source: Iterator[T] = iter(iterable)
+        self._back_buf: Deque[T] = deque(maxlen=history)
+        self._ahead_buf: Deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque()
+        self._cursor: int = 0
+        self._history = history
+        self._lookahead = lookahead
+
+    def __iter__(self) -> "Backtrackable[T]":
+        return self
+
+    def __next__(self) -> T:
+        # If we've stepped back, advance within the back buffer first
+        if self._cursor < 0:  # cursor == -k means "k items behind the latest"
+            self._cursor += 1
+            return self._back_buf[self._cursor - 1]
+
+        # If we have items in the ahead buffer, use them first
+        item = self._ahead_buf.popleft() if self._ahead_buf else next(self._source)
+
+        # Add current item to back buffer and reset cursor
+        self._back_buf.append(item)
+        self._cursor = 0
+        return item
+
+    def prev(self) -> T:
+        """
+        Step one item back in history and return it.
+        Raises LookBackError if already at the oldest buffered item.
+        """
+        if len(self._back_buf) + self._cursor <= 1:
+            raise LookBackError("At start of history")
+
+        self._cursor -= 1
+        return self._back_buf[self._cursor - 1]
+
+    def peek_back(self, n: int = 1) -> T:
+        """
+        Look `n` items back (n=1 == previous item) without moving the cursor.
+        """
+        if n < 0 or n + 1 > len(self._back_buf) + self._cursor:
+            raise LookBackError("peek_back distance out of range")
+
+        return self._back_buf[self._cursor - (n + 1)]
+
+    def peek_ahead(self, n: int = 1) -> T:
+        """
+        Look `n` items ahead (n=1 == next item) without moving the cursor.
+        Fills the ahead buffer if necessary.
+        """
+        if n < 1:
+            raise LookAheadError("peek_ahead distance must be 1 or more")
+        elif n > self._lookahead:
+            raise LookAheadError("peek_ahead distance exceeds lookahead limit")
+
+        # Fill ahead buffer if we don't have enough items
+        while len(self._ahead_buf) < n:
+            try:
+                item = next(self._source)
+                self._ahead_buf.append(item)
+
+            except StopIteration as err:
+                raise LookAheadError("peek_ahead: not enough items in source") from err
+
+        return self._ahead_buf[n - 1]
+
+    def history(self) -> list[T]:
+        """
+        Return a copy of the buffered history (most recent last).
+        The list length is <= the `history` argument passed at construction.
+        """
+        if self._cursor == 0:
+            return list(self._back_buf)
+
+        # When cursor < 0, slice so the order remains chronological
+        return list(self._back_buf)[: self._cursor or None]
+
+    def lookahead_buffer(self) -> list[T]:
+        """
+        Return a copy of the current lookahead buffer.
+        """
+        return list(self._ahead_buf)
+
+    def can_peek_back(self, steps: int = 1) -> bool:
+        """
+        Check whether we can go back `steps` items without raising a LookBackError.
+ """ + return steps <= len(self._back_buf) + self._cursor + + def can_peek_ahead(self, steps: int = 1) -> bool: + """ + Check if we can peek ahead `steps` items. + This may involve trying to fill the ahead buffer. + """ + if self._lookahead > 0 and steps > self._lookahead: + return False + + # Try to fill ahead buffer to check if we can peek that far + try: + while len(self._ahead_buf) < steps: + if self._lookahead > 0 and len(self._ahead_buf) >= self._lookahead: + return False + item = next(self._source) + self._ahead_buf.append(item) + return True + except StopIteration: + return False + + def reset_cursor(self) -> None: + """ + Reset cursor to the most recent position (equivalent to calling next() + until you're back to the latest item). + """ + self._cursor = 0 + + def clear_ahead_buffer(self) -> None: + """ + Clear the ahead buffer, discarding any pre-fetched items. + """ + self._ahead_buf.clear() + + def switch_source_iterable(self, new_source: Iterable[T]) -> None: + """ + Switch the source of the backtrackable to a new iterable, keeping the history. + + This is useful when iterating over a sequence of datasets. The history from the + previous source is kept, but the lookahead buffer is cleared. The cursor is reset + to the present. + """ + self._source = iter(new_source) + self.clear_ahead_buffer() + self.reset_cursor() + + +def safe_shard(dataset: datasets.IterableDataset, index: int, num_shards: int) -> datasets.Dataset: + """ + Safe shards the dataset. + """ + shard_idx = min(dataset.num_shards, index + 1) - 1 + + return dataset.shard(num_shards, index=shard_idx) diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index 9d7df8d6..9da89022 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -21,9 +21,11 @@ import tempfile import warnings from dataclasses import dataclass, field from pathlib import Path +from threading import Lock from typing import Any, ClassVar import av +import fsspec import pyarrow as pa import torch import torchvision @@ -169,15 +171,68 @@ def decode_video_frames_torchvision( return closest_frames +class VideoDecoderCache: + """Thread-safe cache for video decoders to avoid expensive re-initialization.""" + + def __init__(self): + self._cache: dict[str, tuple[Any, Any]] = {} + self._lock = Lock() + + def get_decoder(self, video_path: str): + """Get a cached decoder or create a new one.""" + if importlib.util.find_spec("torchcodec"): + from torchcodec.decoders import VideoDecoder + else: + raise ImportError("torchcodec is required but not available.") + + video_path = str(video_path) + + with self._lock: + if video_path not in self._cache: + file_handle = fsspec.open(video_path).__enter__() + decoder = VideoDecoder(file_handle, seek_mode="approximate") + self._cache[video_path] = (decoder, file_handle) + + return self._cache[video_path][0] + + def clear(self): + """Clear the cache and close file handles.""" + with self._lock: + for _, file_handle in self._cache.values(): + file_handle.close() + self._cache.clear() + + def size(self) -> int: + """Return the number of cached decoders.""" + with self._lock: + return len(self._cache) + + +class FrameTimestampError(ValueError): + """Helper error to indicate the retrieved timestamps exceed the queried ones""" + + pass + + +_default_decoder_cache = VideoDecoderCache() + + def decode_video_frames_torchcodec( video_path: Path | str, timestamps: list[float], tolerance_s: float, - device: str = "cpu", log_loaded_timestamps: bool = False, + 
diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py
index 9d7df8d6..9da89022 100644
--- a/src/lerobot/datasets/video_utils.py
+++ b/src/lerobot/datasets/video_utils.py
@@ -21,9 +21,11 @@ import tempfile
 import warnings
 from dataclasses import dataclass, field
 from pathlib import Path
+from threading import Lock
 from typing import Any, ClassVar
 
 import av
+import fsspec
 import pyarrow as pa
 import torch
 import torchvision
@@ -169,15 +171,68 @@ def decode_video_frames_torchvision(
     return closest_frames
 
 
+class VideoDecoderCache:
+    """Thread-safe cache for video decoders to avoid expensive re-initialization."""
+
+    def __init__(self):
+        self._cache: dict[str, tuple[Any, Any]] = {}
+        self._lock = Lock()
+
+    def get_decoder(self, video_path: str):
+        """Get a cached decoder or create a new one."""
+        if importlib.util.find_spec("torchcodec"):
+            from torchcodec.decoders import VideoDecoder
+        else:
+            raise ImportError("torchcodec is required but not available.")
+
+        video_path = str(video_path)
+
+        with self._lock:
+            if video_path not in self._cache:
+                file_handle = fsspec.open(video_path).__enter__()
+                decoder = VideoDecoder(file_handle, seek_mode="approximate")
+                self._cache[video_path] = (decoder, file_handle)
+
+            return self._cache[video_path][0]
+
+    def clear(self):
+        """Clear the cache and close file handles."""
+        with self._lock:
+            for _, file_handle in self._cache.values():
+                file_handle.close()
+            self._cache.clear()
+
+    def size(self) -> int:
+        """Return the number of cached decoders."""
+        with self._lock:
+            return len(self._cache)
+
+
+class FrameTimestampError(ValueError):
+    """Raised when the retrieved frames do not match the queried timestamps."""
+
+    pass
+
+
+_default_decoder_cache = VideoDecoderCache()
+
+
 def decode_video_frames_torchcodec(
     video_path: Path | str,
     timestamps: list[float],
     tolerance_s: float,
-    device: str = "cpu",
     log_loaded_timestamps: bool = False,
+    decoder_cache: VideoDecoderCache | None = None,
 ) -> torch.Tensor:
     """Loads frames associated with the requested timestamps of a video using torchcodec.
 
+    Args:
+        video_path: Path to the video file.
+        timestamps: List of timestamps to extract frames.
+        tolerance_s: Allowed deviation in seconds for frame retrieval.
+        log_loaded_timestamps: Whether to log loaded timestamps.
+        decoder_cache: Optional decoder cache instance. Uses the default cache if None.
+
     Note: Setting device="cuda" outside the main process, e.g. in data loader workers, will lead to CUDA initialization errors.
 
     Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
@@ -186,27 +241,24 @@ def decode_video_frames_torchcodec(
     and all subsequent frames until reaching the requested frame. The number of key frames in a video
     can be adjusted during encoding to take into account decoding time and video size in bytes.
     """
+    if decoder_cache is None:
+        decoder_cache = _default_decoder_cache
 
-    if importlib.util.find_spec("torchcodec"):
-        from torchcodec.decoders import VideoDecoder
-    else:
-        raise ImportError("torchcodec is required but not available.")
+    # Use the cached decoder instead of creating a new one for each call
+    decoder = decoder_cache.get_decoder(str(video_path))
 
-    # initialize video decoder
-    decoder = VideoDecoder(video_path, device=device, seek_mode="approximate")
-    loaded_frames = []
     loaded_ts = []
+    loaded_frames = []
 
+    # get metadata for frame information
     metadata = decoder.metadata
     average_fps = metadata.average_fps
-    # convert timestamps to frame indices
     frame_indices = [round(ts * average_fps) for ts in timestamps]
-    # retrieve frames based on indices
     frames_batch = decoder.get_frames_at(indices=frame_indices)
-    for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=False):
+    for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=True):
         loaded_frames.append(frame)
         loaded_ts.append(pts.item())
         if log_loaded_timestamps:
@@ -237,10 +289,14 @@ def decode_video_frames_torchcodec(
     if log_loaded_timestamps:
         logging.info(f"{closest_ts=}")
 
-    # convert to float32 in [0,1] range (channel first)
-    closest_frames = closest_frames.type(torch.float32) / 255
+    # convert to float32 in [0,1] range
+    closest_frames = (closest_frames / 255.0).type(torch.float32)
+
+    if len(timestamps) != len(closest_frames):
+        raise FrameTimestampError(
+            f"Retrieved {len(closest_frames)} frames for the {len(timestamps)} queried timestamps: {timestamps}"
+        )
 
-    assert len(timestamps) == len(closest_frames)
     return closest_frames
diff --git a/src/lerobot/scripts/train.py b/src/lerobot/scripts/train.py
index ba3db607..398bea90 100644
--- a/src/lerobot/scripts/train.py
+++ b/src/lerobot/scripts/train.py
@@ -179,10 +179,11 @@ def train(cfg: TrainPipelineConfig):
         dataset,
         num_workers=cfg.num_workers,
         batch_size=cfg.batch_size,
-        shuffle=shuffle,
+        shuffle=shuffle and not cfg.dataset.streaming,
         sampler=sampler,
         pin_memory=device.type == "cuda",
         drop_last=False,
+        prefetch_factor=2 if cfg.num_workers > 0 else None,  # prefetch_factor must be None without workers
     )
     dl_iter = cycle(dataloader)
 
@@ -208,6 +209,9 @@ def train(cfg: TrainPipelineConfig):
 
         for key in batch:
             if isinstance(batch[key], torch.Tensor):
+                if batch[key].dtype != torch.bool:
+                    # MPS does not support float64; cast non-bool tensors to float32 on that device
+                    batch[key] = batch[key].type(torch.float32) if device.type == "mps" else batch[key]
+
                 batch[key] = batch[key].to(device, non_blocking=device.type == "cuda")
 
         train_tracker, output_dict = update_policy(
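The `shuffle=shuffle and not cfg.dataset.streaming` change reflects a DataLoader constraint: `StreamingLeRobotDataset` is an `IterableDataset`, which defines its own iteration order (and handles shuffling via its internal buffer), so the loader must not be asked to shuffle it. A minimal illustration of the constraint:

```python
import torch


class Ones(torch.utils.data.IterableDataset):
    def __iter__(self):
        yield from (torch.ones(1) for _ in range(8))


# shuffle=True (or passing a sampler) with an IterableDataset raises a ValueError;
# shuffle must stay False and any randomness must come from the dataset itself.
loader = torch.utils.data.DataLoader(Ones(), batch_size=4, shuffle=False)
for batch in loader:
    print(batch.shape)  # torch.Size([4, 1])
```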
diff --git a/tests/datasets/test_streaming.py b/tests/datasets/test_streaming.py
new file mode 100644
index 00000000..506be3ec
--- /dev/null
+++ b/tests/datasets/test_streaming.py
@@ -0,0 +1,391 @@
+#!/usr/bin/env python
+
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import pytest
+import torch
+
+from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
+from lerobot.datasets.utils import safe_shard
+from tests.fixtures.constants import DUMMY_REPO_ID
+
+
+def get_frames_expected_order(streaming_ds: StreamingLeRobotDataset) -> list[int]:
+    """Replicates the shuffling logic of StreamingLeRobotDataset to get the expected order of indices."""
+    rng = np.random.default_rng(streaming_ds.seed)
+    buffer_size = streaming_ds.buffer_size
+    num_shards = streaming_ds.num_shards
+
+    shards_indices = []
+    for shard_idx in range(num_shards):
+        shard = streaming_ds.hf_dataset.shard(num_shards, index=shard_idx)
+        shard_indices = [item["index"] for item in shard]
+        shards_indices.append(shard_indices)
+
+    shard_iterators = {i: iter(s) for i, s in enumerate(shards_indices)}
+
+    buffer_indices_generator = streaming_ds._iter_random_indices(rng, buffer_size)
+
+    frames_buffer = []
+    expected_indices = []
+
+    while shard_iterators:  # While there are still available shards
+        available_shard_keys = list(shard_iterators.keys())
+
+        # Re-create the sampling generator from the currently available shards, mirroring __iter__
+        shard_key = next(streaming_ds._infinite_generator_over_elements(rng, available_shard_keys))
+
+        try:
+            frame_index = next(shard_iterators[shard_key])
+
+            if len(frames_buffer) == buffer_size:
+                i = next(buffer_indices_generator)
+                expected_indices.append(frames_buffer[i])
+                frames_buffer[i] = frame_index
+            else:
+                frames_buffer.append(frame_index)
+
+        except StopIteration:
+            del shard_iterators[shard_key]  # Remove exhausted shard
+
+    rng.shuffle(frames_buffer)
+    expected_indices.extend(frames_buffer)
+
+    return expected_indices
+
+
+def test_single_frame_consistency(tmp_path, lerobot_dataset_factory):
+    """Test that streamed frames match the frames obtained by indexing the in-memory dataset."""
+    ds_num_frames = 400
+    ds_num_episodes = 10
+    buffer_size = 100
+
+    local_path = tmp_path / "test"
+    repo_id = DUMMY_REPO_ID
+
+    ds = lerobot_dataset_factory(
+        root=local_path,
+        repo_id=repo_id,
+        total_episodes=ds_num_episodes,
+        total_frames=ds_num_frames,
+    )
+
+    streaming_ds = iter(StreamingLeRobotDataset(repo_id=repo_id, root=local_path, buffer_size=buffer_size))
+
+    key_checks = []
+    for _ in range(ds_num_frames):
+        streaming_frame = next(streaming_ds)
+        frame_idx = streaming_frame["index"]
+        target_frame = ds[frame_idx]
+
+        for key in streaming_frame:
+            left = streaming_frame[key]
+            right = target_frame[key]
+
+            if isinstance(left, str):
+                check = left == right
+
+            elif isinstance(left, torch.Tensor):
+                check = torch.allclose(left, right) and left.shape == right.shape
+
+            elif isinstance(left, float):
+                check = left == right.item()  # right is a torch.Tensor
+
+            key_checks.append((key, check))
+
+    assert all(t[1] for t in key_checks), (
+        f"Key {list(filter(lambda t: not t[1], key_checks))[0][0]}: streaming and indexed values differ (frame_idx: {frame_idx})"
+    )
+
+
+@pytest.mark.parametrize(
+    "shuffle",
+    [False, True],
+)
+def test_frames_order_over_epochs(tmp_path, lerobot_dataset_factory, shuffle):
+    """Test that streamed frame order matches the shuffling logic replicated over the in-memory dataset."""
+    ds_num_frames = 400
+    ds_num_episodes = 10
+    buffer_size = 100
+    seed = 42
+    n_epochs = 3
+
+    local_path = tmp_path / "test"
+    repo_id = DUMMY_REPO_ID
+
+    lerobot_dataset_factory(
+        root=local_path,
+        repo_id=repo_id,
+        total_episodes=ds_num_episodes,
+        total_frames=ds_num_frames,
+    )
+
+    streaming_ds = StreamingLeRobotDataset(
+        repo_id=repo_id, root=local_path, buffer_size=buffer_size, seed=seed, shuffle=shuffle
+    )
+
+    first_epoch_indices = [frame["index"] for frame in streaming_ds]
+    expected_indices = get_frames_expected_order(streaming_ds)
+
+    assert first_epoch_indices == expected_indices, "First epoch indices do not match expected indices"
+
+    expected_indices = get_frames_expected_order(streaming_ds)
+    for _ in range(n_epochs):
+        streaming_indices = [frame["index"] for frame in streaming_ds]
+        frames_match = all(
+            s_index == e_index for s_index, e_index in zip(streaming_indices, expected_indices, strict=True)
+        )
+
+        if shuffle:
+            assert not frames_match
+        else:
+            assert frames_match
+
+
+@pytest.mark.parametrize(
+    "shuffle",
+    [False, True],
+)
+def test_frames_order_with_shards(tmp_path, lerobot_dataset_factory, shuffle):
+    """Test that streamed frame order matches the replicated shuffling logic when the dataset has multiple shards."""
+    ds_num_frames = 100
+    ds_num_episodes = 10
+    buffer_size = 10
+
+    seed = 42
+    n_epochs = 3
+    data_file_size_mb = 0.001
+
+    chunks_size = 1
+
+    local_path = tmp_path / "test"
+    repo_id = f"{DUMMY_REPO_ID}-ciao"
+
+    lerobot_dataset_factory(
+        root=local_path,
+        repo_id=repo_id,
+        total_episodes=ds_num_episodes,
+        total_frames=ds_num_frames,
+        data_files_size_in_mb=data_file_size_mb,
+        chunks_size=chunks_size,
+    )
+
+    streaming_ds = StreamingLeRobotDataset(
+        repo_id=repo_id,
+        root=local_path,
+        buffer_size=buffer_size,
+        seed=seed,
+        shuffle=shuffle,
+        max_num_shards=4,
+    )
+
+    first_epoch_indices = [frame["index"] for frame in streaming_ds]
+    expected_indices = get_frames_expected_order(streaming_ds)
+
+    assert first_epoch_indices == expected_indices, "First epoch indices do not match expected indices"
+
+    for _ in range(n_epochs):
+        # Identical to first_epoch_indices only when shuffle=False
+        streaming_indices = [frame["index"] for frame in streaming_ds]
+        frames_match = all(
+            s_index == e_index for s_index, e_index in zip(streaming_indices, expected_indices, strict=True)
+        )
+        if shuffle:
+            assert not frames_match
+        else:
+            assert frames_match
+
+
+@pytest.mark.parametrize(
+    "state_deltas, action_deltas",
+    [
+        ([-1, -0.5, -0.20, 0], [0, 1, 2, 3]),
+        ([-1, -0.5, -0.20, 0], [-1.5, -1, -0.5, -0.20, -0.10, 0]),
+        ([-2, -1, -0.5, 0], [0, 1, 2, 3]),
+        ([-2, -1, -0.5, 0], [-1.5, -1, -0.5, -0.20, -0.10, 0]),
+    ],
+)
+def test_frames_with_delta_consistency(tmp_path, lerobot_dataset_factory, state_deltas, action_deltas):
+    ds_num_frames = 500
+    ds_num_episodes = 10
+    buffer_size = 100
+
+    seed = 42
+
+    local_path = tmp_path / "test"
+    repo_id = f"{DUMMY_REPO_ID}-ciao"
+    camera_key = "phone"
+
+    delta_timestamps = {
+        camera_key: state_deltas,
+        "state": state_deltas,
+        "action": action_deltas,
+    }
+
+    ds = lerobot_dataset_factory(
+        root=local_path,
+        repo_id=repo_id,
+        total_episodes=ds_num_episodes,
+        total_frames=ds_num_frames,
+        delta_timestamps=delta_timestamps,
+    )
+
+    streaming_ds = iter(
+        StreamingLeRobotDataset(
+            repo_id=repo_id,
+            root=local_path,
+            buffer_size=buffer_size,
+            seed=seed,
+            shuffle=False,
+            delta_timestamps=delta_timestamps,
+        )
+    )
+
+    for i in range(ds_num_frames):
+        streaming_frame = next(streaming_ds)
+        frame_idx = streaming_frame["index"]
+        target_frame = ds[frame_idx]
+
+        assert set(streaming_frame.keys()) == set(target_frame.keys()), (
+            f"Keys differ between streaming frame and target frame: {set(streaming_frame.keys()) - set(target_frame.keys())}"
+        )
+
+        key_checks = []
+        for key in streaming_frame:
+            left = streaming_frame[key]
+            right = target_frame[key]
+
+            if isinstance(left, str):
+                check = left == right
+
+            elif isinstance(left, torch.Tensor):
+                if (
+                    key not in ds.meta.camera_keys
+                    and "is_pad" not in key
+                    and f"{key}_is_pad" in streaming_frame
+                ):
+                    # Compare only the non-padded regions; padded positions repeat the last valid frame
+                    left = left[~streaming_frame[f"{key}_is_pad"]]
+                    right = right[~target_frame[f"{key}_is_pad"]]
+
+                check = torch.allclose(left, right) and left.shape == right.shape
+
+            elif isinstance(left, float):
+                check = left == right.item()  # right is a torch.Tensor
+
+            key_checks.append((key, check))
+
+        assert all(t[1] for t in key_checks), (
+            f"Key {list(filter(lambda t: not t[1], key_checks))[0][0]}: streaming and indexed values differ (i: {i}, frame_idx: {frame_idx})"
+        )
+
+
+@pytest.mark.parametrize(
+    "state_deltas, action_deltas",
+    [
+        ([-1, -0.5, -0.20, 0], [0, 1, 2, 3, 10, 20]),
+        ([-1, -0.5, -0.20, 0], [-20, -1.5, -1, -0.5, -0.20, -0.10, 0]),
+        ([-2, -1, -0.5, 0], [0, 1, 2, 3, 10, 20]),
+        ([-2, -1, -0.5, 0], [-20, -1.5, -1, -0.5, -0.20, -0.10, 0]),
+    ],
+)
+def test_frames_with_delta_consistency_with_shards(
+    tmp_path, lerobot_dataset_factory, state_deltas, action_deltas
+):
+    ds_num_frames = 100
+    ds_num_episodes = 10
+    buffer_size = 10
+    data_file_size_mb = 0.001
+    chunks_size = 1
+
+    seed = 42
+
+    local_path = tmp_path / "test"
+    repo_id = f"{DUMMY_REPO_ID}-ciao"
+    camera_key = "phone"
+
+    delta_timestamps = {
+        camera_key: state_deltas,
+        "state": state_deltas,
+        "action": action_deltas,
+    }
+
+    ds = lerobot_dataset_factory(
+        root=local_path,
+        repo_id=repo_id,
+        total_episodes=ds_num_episodes,
+        total_frames=ds_num_frames,
+        delta_timestamps=delta_timestamps,
+        data_files_size_in_mb=data_file_size_mb,
+        chunks_size=chunks_size,
+    )
+    streaming_ds = StreamingLeRobotDataset(
+        repo_id=repo_id,
+        root=local_path,
+        buffer_size=buffer_size,
+        seed=seed,
+        shuffle=False,
+        delta_timestamps=delta_timestamps,
+        max_num_shards=4,
+    )
+
+    num_shards = 4
+    shards_indices = []
+    for shard_idx in range(num_shards):
+        shard = safe_shard(streaming_ds.hf_dataset, shard_idx, num_shards)
+        shard_indices = [item["index"] for item in shard]
+        shards_indices.append(shard_indices)
+
+    streaming_ds = iter(streaming_ds)
+
+    for i in range(ds_num_frames):
+        streaming_frame = next(streaming_ds)
+        frame_idx = streaming_frame["index"]
+        target_frame = ds[frame_idx]
+
+        assert set(streaming_frame.keys()) == set(target_frame.keys()), (
+            f"Keys differ between streaming frame and target frame: {set(streaming_frame.keys()) - set(target_frame.keys())}"
+        )
+
+        key_checks = []
+        for key in streaming_frame:
+            left = streaming_frame[key]
+            right = target_frame[key]
+
+            if isinstance(left, str):
+                check = left == right
+
+            elif isinstance(left, torch.Tensor):
+                if (
+                    key not in ds.meta.camera_keys
+                    and "is_pad" not in key
+                    and f"{key}_is_pad" in streaming_frame
+                ):
+                    # Compare only the non-padded regions; padded positions repeat the last valid frame
+                    left = left[~streaming_frame[f"{key}_is_pad"]]
+                    right = right[~target_frame[f"{key}_is_pad"]]
+
+                check = torch.allclose(left, right) and left.shape == right.shape
+
+            elif isinstance(left, float):
+                check = left == right.item()  # right is a torch.Tensor
+
+            key_checks.append((key, check))
+
+        assert all(t[1] for t in key_checks), (
+            f"Key {list(filter(lambda t: not t[1], key_checks))[0][0]}: streaming and indexed values differ (i: {i}, frame_idx: {frame_idx})"
+        )