diff --git a/lerobot/common/datasets/profile_streaming_dataset.py b/lerobot/common/datasets/profile_streaming_dataset.py
index aefd59a99..99d0cc8f5 100644
--- a/lerobot/common/datasets/profile_streaming_dataset.py
+++ b/lerobot/common/datasets/profile_streaming_dataset.py
@@ -250,8 +250,24 @@ def profile_dataset(
     stats_file_path = "streaming_dataset_profile.txt"
 
     print(f"Creating dataset from {repo_id} with buffer_size={buffer_size}, max_num_shards={max_num_shards}")
+    camera_key = "observation.images.cam_right_wrist"
+    fps = 50
+
+    delta_timestamps = {
+        # loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
+        camera_key: [-1, -0.5, -0.20, 0],
+        # loads 6 state vectors: 1.5 seconds before, 1 second before, ..., 200 ms before, 100 ms before, and current frame
+        "observation.state": [-1.5, -1, -0.5, -0.20, -0.10, 0],
+        # loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ..., 63 frames in the future
+        "action": [t / fps for t in range(64)],
+    }
+
     dataset = StreamingLeRobotDataset(
-        repo_id=repo_id, buffer_size=buffer_size, max_num_shards=max_num_shards, seed=seed
+        repo_id=repo_id,
+        buffer_size=buffer_size,
+        max_num_shards=max_num_shards,
+        seed=seed,
+        delta_timestamps=delta_timestamps,
     )
 
     _time_iterations(dataset, num_samples, num_runs, warmup_iters, stats_file_path)
diff --git a/lerobot/common/datasets/streaming_dataset.py b/lerobot/common/datasets/streaming_dataset.py
index 0c322c3ce..743579130 100644
--- a/lerobot/common/datasets/streaming_dataset.py
+++ b/lerobot/common/datasets/streaming_dataset.py
@@ -1,6 +1,6 @@
 import random
 from pathlib import Path
-from typing import Callable, Dict, Generator, Iterator
+from typing import Callable, Dict, Generator, Iterator, Tuple
 
 import datasets
 import numpy as np
@@ -11,7 +11,12 @@ from line_profiler import profile
 from lerobot.common.constants import HF_LEROBOT_HOME
 from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
 from lerobot.common.datasets.utils import (
+    Backtrackable,
+    LookAheadError,
+    LookBackError,
+    check_delta_timestamps,
     check_version_compatibility,
+    get_delta_indices,
     item_to_torch,
 )
 from lerobot.common.datasets.video_utils import (
@@ -65,6 +70,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
         root: str | Path | None = None,
         episodes: list[int] | None = None,
         image_transforms: Callable | None = None,
+        delta_timestamps: dict[str, list[float]] | None = None,
         tolerance_s: float = 1e-4,
         revision: str | None = None,
         force_cache_sync: bool = False,
@@ -123,9 +129,19 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
         # Check version
         check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
 
-        self.hf_dataset = self.load_hf_dataset()
+        if delta_timestamps is not None:
+            self._validate_delta_timestamp_keys(delta_timestamps)  # raises ValueError if invalid
+        self.delta_timestamps = delta_timestamps
+        self.delta_indices = None  # populated by _get_window_steps() when delta_timestamps is provided
+
+        self.hf_dataset: datasets.IterableDataset = self.load_hf_dataset()
         self.num_shards = min(self.hf_dataset.num_shards, max_num_shards)
 
+        max_backward_steps, max_forward_steps = self._get_window_steps()
+        self.backtrackable_dataset: Backtrackable = Backtrackable(
+            self.hf_dataset, history=max_backward_steps, lookahead=max_forward_steps
+        )
+
     @property
     def fps(self):
         return self.meta.fps
@@ -148,20 +163,42 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
         # TODO(fracapuano): Add support for streaming from a local folder and not only from HF Hub
         return dataset
 
+    def _get_window_steps(self) -> Tuple[int, int]:
+        """
+        Returns how many steps backward (and forward) the backtrackable iterator should maintain,
+        based on the input delta_timestamps.
+        """
+        max_backward_steps = 1
+        max_forward_steps = 1
+
+        if self.delta_timestamps is not None:
+            check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
+            self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
+
+            # Compute the maximum backward (history) and forward (lookahead) steps needed
+            for delta_idx in self.delta_indices.values():
+                min_delta = min(delta_idx)
+                max_delta = max(delta_idx)
+                if min_delta < 0:
+                    max_backward_steps = max(max_backward_steps, abs(min_delta))
+                if max_delta > 0:
+                    max_forward_steps = max(max_forward_steps, max_delta)
+
+        return max_backward_steps, max_forward_steps
+
     def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
         buffer_indices_generator = self._iter_random_indices(self.rng, self.buffer_size)
-
-        # This buffer is populated while iterating on the dataset's shards
-        frames_buffer = []
-        idx_to_iterable_dataset = {
-            idx: self._make_iterable_dataset(self.hf_dataset.shard(self.num_shards, index=idx))
+        idx_to_backtrackable_dataset = {
+            idx: self._make_backtrackable_dataset(self.hf_dataset.shard(self.num_shards, index=idx))
             for idx in range(self.num_shards)
         }
+        # This buffer is populated while iterating on the dataset's shards
+        frames_buffer = []
 
         try:
-            while available_shards := list(idx_to_iterable_dataset.keys()):
+            while available_shards := list(idx_to_backtrackable_dataset.keys()):
                 shard_key = next(self._infinite_generator_over_elements(available_shards))
-                dataset = idx_to_iterable_dataset[shard_key]  # selects which shard to iterate on
+                dataset = idx_to_backtrackable_dataset[shard_key]  # selects which shard to iterate on
                 for frame in self.make_frame(dataset):
                     if len(frames_buffer) == self.buffer_size:
                         i = next(buffer_indices_generator)
@@ -175,18 +212,18 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
             RuntimeError,
             StopIteration,
         ):  # NOTE: StopIteration inside a generator throws a RuntimeError since 3.7
-            # Remove exhausted shard
-            del idx_to_iterable_dataset[shard_key]
+            del idx_to_backtrackable_dataset[shard_key]  # remove the exhausted shard and move on to another one
 
         # Once shards are all exhausted, shuffle the buffer and yield the remaining frames
         self.rng.shuffle(frames_buffer)
         yield from frames_buffer
 
-    def _make_iterable_dataset(self, dataset: datasets.IterableDataset) -> Iterator:
-        return iter(dataset)
+    def _make_backtrackable_dataset(self, dataset: datasets.IterableDataset) -> Backtrackable:
+        history, lookahead = self._get_window_steps()
+        return Backtrackable(dataset, history=history, lookahead=lookahead)
 
     @profile
-    def make_frame(self, dataset_iterator: datasets.IterableDataset) -> Generator:
+    def make_frame(self, dataset_iterator: Backtrackable) -> Generator:
         """Makes a frame starting from a dataset iterator"""
         item = next(dataset_iterator)
         item = item_to_torch(item)
@@ -194,16 +231,20 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
         # Get episode index from the item
         ep_idx = item["episode_index"]
 
+        # Apply delta querying logic if necessary
+        if self.delta_indices is not None:
+            query_result, padding = self._get_delta_frames(dataset_iterator, item)
+            item = {**item, **query_result, **padding}
+
         # Load video frames, when needed
         if len(self.meta.video_keys) > 0:
             current_ts = item["timestamp"]
-            query_timestamps = self._get_query_timestamps(current_ts, None)
+            query_indices = self.delta_indices if self.delta_indices is not None else None
+            query_timestamps = self._get_query_timestamps(current_ts, query_indices)
             video_frames = self._query_videos(query_timestamps, ep_idx)
             item = {**video_frames, **item}
 
-        # Add task as a string
-        task_idx = item["task_index"]
-        item["task"] = self.meta.tasks.iloc[task_idx].name
+        item["task"] = self.meta.tasks.iloc[item["task_index"]].name
 
         yield item
 
@@ -215,8 +256,9 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
         query_timestamps = {}
         for key in self.meta.video_keys:
             if query_indices is not None and key in query_indices:
-                timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
-                query_timestamps[key] = torch.stack(timestamps).tolist()
+                timestamps = current_ts + torch.tensor(query_indices[key]) / self.fps
+                # never query for negative timestamps!
+                query_timestamps[key] = list(filter(lambda x: x >= 0, timestamps.tolist()))
             else:
                 query_timestamps[key] = [current_ts]
 
@@ -239,14 +281,145 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
         return item
 
+    def _get_future_frame(self, dataset_iterator: Backtrackable, n_steps: int):
+        if dataset_iterator.can_go_ahead(n_steps):
+            return dataset_iterator.peek_ahead(n_steps)
+        else:
+            pass
+
+    def _get_previous_frame(self, dataset_iterator: Backtrackable, n_steps: int):
+        if dataset_iterator.can_go_back(n_steps):
+            return dataset_iterator.peek_back(n_steps)
+        else:
+            pass
+
+    def _get_delta_frames(self, dataset_iterator: Backtrackable, current_item: dict):
+        """Get frames with delta offsets using the backtrackable iterator.
+
+        Args:
+            dataset_iterator (Backtrackable): Iterator positioned at the current frame.
+            current_item (dict): Current item from the iterator.
+
+        Returns:
+            tuple: (query_result, padding) - frames at delta offsets and padding info.
+        """
+        current_episode_idx = current_item["episode_index"]
+
+        # Prepare results
+        query_result = {}
+        padding = {}
+
+        for key, delta_indices in self.delta_indices.items():
+            if key in self.meta.video_keys:
+                continue  # visual frames are decoded separately
+
+            target_frames = []
+            is_pad = []
+
+            # NOTE(fracapuano): Optimize this. What's the point in checking all deltas after the first error?
+            for delta in delta_indices:
+                if delta == 0:
+                    # Current frame
+                    target_frames.append(current_item[key])
+                    is_pad.append(False)
+
+                elif delta < 0:
+                    # Past frame: look back abs(delta) steps with the backtrackable iterator
+                    try:
+                        steps_back = abs(delta)
+                        if dataset_iterator.can_peek_back(steps_back):
+                            past_item = dataset_iterator.peek_back(steps_back)
+                            past_item = item_to_torch(past_item)
+
+                            # Check if it's from the same episode
+                            if past_item["episode_index"] == current_episode_idx:
+                                target_frames.append(past_item[key])
+                                is_pad.append(False)
+
+                            else:
+                                raise LookBackError("Retrieved frame is from a different episode!")
+                        else:
+                            raise LookBackError("Cannot go back further than the history buffer!")
+
+                    except LookBackError:
+                        target_frames.append(torch.zeros_like(current_item[key]))
+                        is_pad.append(True)
+
+                elif delta > 0:
+                    # Future frame: read ahead from the iterator
+                    try:
+                        if dataset_iterator.can_peek_ahead(delta):
+                            future_item = dataset_iterator.peek_ahead(delta)
+                            future_item = item_to_torch(future_item)
+
+                            # Check if it's from the same episode
+                            if future_item["episode_index"] == current_episode_idx:
+                                target_frames.append(future_item[key])
+                                is_pad.append(False)
+
+                            else:
+                                raise LookAheadError("Retrieved frame is from a different episode!")
+                        else:
+                            raise LookAheadError("Cannot go ahead further than the lookahead buffer!")
+
+                    except LookAheadError:
+                        target_frames.append(torch.zeros_like(current_item[key]))
+                        is_pad.append(True)
+
+            # Stack frames and add to results
+            if target_frames:
+                query_result[key] = torch.stack(target_frames)
+                padding[f"{key}.pad_masking"] = torch.BoolTensor(is_pad)
+
+        return query_result, padding
+
+    def _validate_delta_timestamp_keys(self, delta_timestamps: dict[str, list[float]]) -> None:
+        """
+        Validate that all keys in delta_timestamps correspond to actual features in the dataset.
+
+        Raises:
+            ValueError: If any delta timestamp key doesn't correspond to a dataset feature.
+        """
+        if delta_timestamps is None:
+            return
+
+        # Get all available feature keys from the dataset metadata
+        available_features = set(self.meta.features.keys())
+
+        # Get all keys from delta_timestamps
+        delta_keys = set(delta_timestamps.keys())
+
+        # Find any keys that don't correspond to features
+        invalid_keys = delta_keys - available_features
+
+        if invalid_keys:
+            raise ValueError(
+                f"The following delta_timestamp keys do not correspond to dataset features: {invalid_keys}. "
+                f"Available features are: {sorted(available_features)}"
+            )
+
 
 # Example usage
 if __name__ == "__main__":
+    from tqdm import tqdm
+
     repo_id = "lerobot/aloha_mobile_cabinet"
-    dataset = StreamingLeRobotDataset(repo_id)
 
-    for i, frame in enumerate(dataset):
+    camera_key = "observation.images.cam_right_wrist"
+    fps = 50
+
+    delta_timestamps = {
+        # loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
+        camera_key: [-1, -0.5, -0.20, 0],
+        # loads 6 state vectors: 1.5 seconds before, 1 second before, ..., 200 ms before, 100 ms before, and current frame
+        "observation.state": [-1.5, -1, -0.5, -0.20, -0.10, 0],
+        # loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ..., 63 frames in the future
+        "action": [t / fps for t in range(64)],
+    }
+
+    dataset = StreamingLeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
+
+    for i, frame in tqdm(enumerate(dataset)):
        print(frame)
-
-        if i > 10:  # only stream first 10 frames
+        if i > 1000:  # only stream the first ~1000 frames
            break
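
For reference (not part of the patch): the streaming_dataset.py hunks import `Backtrackable`, `LookBackError`, and `LookAheadError` from `lerobot/common/datasets/utils.py`, but that file is not shown in this diff. The snippet below is only a rough sketch of the interface the new code relies on (a constructor taking `history` and `lookahead`, `can_peek_back`/`peek_back`, `can_peek_ahead`/`peek_ahead`, and plain iteration via `next()`); it is an assumption for illustration, not the implementation that ships in utils.py. Note also that the unused helpers `_get_future_frame`/`_get_previous_frame` call `can_go_ahead`/`can_go_back`, which this sketch does not define.

from collections import deque
from typing import Iterable, Iterator


class LookBackError(IndexError):
    """Raised when peeking further back than the retained history."""


class LookAheadError(IndexError):
    """Raised when peeking further ahead than the lookahead buffer allows."""


class Backtrackable:
    """Wraps an iterator, retaining `history` past items and pre-fetching up to `lookahead` future items."""

    def __init__(self, iterable: Iterable, history: int = 1, lookahead: int = 1) -> None:
        self._it: Iterator = iter(iterable)
        self._past: deque = deque(maxlen=history)  # most recent item first
        self._ahead: deque = deque()  # pre-fetched items, oldest first
        self._lookahead = lookahead
        self._started = False
        self._current = None

    def __iter__(self) -> "Backtrackable":
        return self

    def __next__(self):
        if self._started:
            self._past.appendleft(self._current)
        self._started = True
        self._current = self._ahead.popleft() if self._ahead else next(self._it)
        return self._current

    def can_peek_back(self, steps: int) -> bool:
        return 0 < steps <= len(self._past)

    def peek_back(self, steps: int):
        if not self.can_peek_back(steps):
            raise LookBackError(f"Only {len(self._past)} past item(s) are retained.")
        return self._past[steps - 1]

    def can_peek_ahead(self, steps: int) -> bool:
        if not 0 < steps <= self._lookahead:
            return False
        try:
            while len(self._ahead) < steps:
                self._ahead.append(next(self._it))  # pre-fetch without consuming
        except StopIteration:
            return False
        return True

    def peek_ahead(self, steps: int):
        if not self.can_peek_ahead(steps):
            raise LookAheadError(f"Cannot look {steps} step(s) ahead of the current item.")
        return self._ahead[steps - 1]

Under these assumptions, a bounded deque keeps the history window small while `peek_back` stays O(1), and the pre-fetch in `can_peek_ahead` is what would let `_get_delta_frames` inspect future frames without consuming them from the stream.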