add: support for delta timesteps on streaming mode

This commit is contained in:
fracapuano
2025-05-29 23:11:28 +02:00
parent e405f37b95
commit 90969e57ff
2 changed files with 214 additions and 25 deletions

View File

@@ -250,8 +250,24 @@ def profile_dataset(
stats_file_path = "streaming_dataset_profile.txt" stats_file_path = "streaming_dataset_profile.txt"
print(f"Creating dataset from {repo_id} with buffer_size={buffer_size}, max_num_shards={max_num_shards}") print(f"Creating dataset from {repo_id} with buffer_size={buffer_size}, max_num_shards={max_num_shards}")
camera_key = "observation.images.cam_right_wrist"
fps = 50
delta_timestamps = {
# loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
camera_key: [-1, -0.5, -0.20, 0],
# loads 6 state vectors: 1.5 seconds before, 1 second before, ... 200 ms, 100 ms, and current frame
"observation.state": [-1.5, -1, -0.5, -0.20, -0.10, 0],
# loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
"action": [t / fps for t in range(64)],
}
dataset = StreamingLeRobotDataset( dataset = StreamingLeRobotDataset(
repo_id=repo_id, buffer_size=buffer_size, max_num_shards=max_num_shards, seed=seed repo_id=repo_id,
buffer_size=buffer_size,
max_num_shards=max_num_shards,
seed=seed,
delta_timestamps=delta_timestamps,
) )
_time_iterations(dataset, num_samples, num_runs, warmup_iters, stats_file_path) _time_iterations(dataset, num_samples, num_runs, warmup_iters, stats_file_path)

View File

@@ -1,6 +1,6 @@
import random import random
from pathlib import Path from pathlib import Path
from typing import Callable, Dict, Generator, Iterator from typing import Callable, Dict, Generator, Iterator, Tuple
import datasets import datasets
import numpy as np import numpy as np
@@ -11,7 +11,12 @@ from line_profiler import profile
from lerobot.common.constants import HF_LEROBOT_HOME from lerobot.common.constants import HF_LEROBOT_HOME
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
from lerobot.common.datasets.utils import ( from lerobot.common.datasets.utils import (
Backtrackable,
LookAheadError,
LookBackError,
check_delta_timestamps,
check_version_compatibility, check_version_compatibility,
get_delta_indices,
item_to_torch, item_to_torch,
) )
from lerobot.common.datasets.video_utils import ( from lerobot.common.datasets.video_utils import (
@@ -65,6 +70,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
root: str | Path | None = None, root: str | Path | None = None,
episodes: list[int] | None = None, episodes: list[int] | None = None,
image_transforms: Callable | None = None, image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None,
tolerance_s: float = 1e-4, tolerance_s: float = 1e-4,
revision: str | None = None, revision: str | None = None,
force_cache_sync: bool = False, force_cache_sync: bool = False,
@@ -123,9 +129,18 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
# Check version # Check version
check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION) check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
self.hf_dataset = self.load_hf_dataset() if delta_timestamps is not None:
self._validate_delta_timestamp_keys(delta_timestamps) # raises ValueError if invalid
self.delta_timestamps = delta_timestamps
self.hf_dataset: datasets.IterableDataset = self.load_hf_dataset()
self.num_shards = min(self.hf_dataset.num_shards, max_num_shards) self.num_shards = min(self.hf_dataset.num_shards, max_num_shards)
max_backward_steps, max_forward_steps = self._get_window_steps()
self.backtrackable_dataset: Backtrackable = Backtrackable(
self.hf_dataset, history=max_backward_steps, lookahead=max_forward_steps
)
@property @property
def fps(self): def fps(self):
return self.meta.fps return self.meta.fps
@@ -148,20 +163,42 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
# TODO(fracapuano): Add support for streaming from a local folder and not only from HF Hub # TODO(fracapuano): Add support for streaming from a local folder and not only from HF Hub
return dataset return dataset
def _get_window_steps(self) -> Tuple[int, int]:
"""
Returns how many steps backward (& forward) should the backtrackable iterator maintain,
based on the input delta_timestamps.
"""
max_backward_steps = 1
max_forward_steps = 1
if self.delta_timestamps is not None:
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
# Calculate maximum backward steps needed (i.e., history size)
for delta_idx in self.delta_indices.values():
min_delta = min(delta_idx)
max_delta = max(delta_idx)
if min_delta < 0:
max_backward_steps = max(max_backward_steps, abs(min_delta))
if max_delta > 0:
max_forward_steps = max(max_forward_steps, max_delta)
return max_backward_steps, max_forward_steps
def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
buffer_indices_generator = self._iter_random_indices(self.rng, self.buffer_size) buffer_indices_generator = self._iter_random_indices(self.rng, self.buffer_size)
idx_to_backtracktable_dataset = {
# This buffer is populated while iterating on the dataset's shards idx: self._make_backtrackable_dataset(self.hf_dataset.shard(self.num_shards, index=idx))
frames_buffer = []
idx_to_iterable_dataset = {
idx: self._make_iterable_dataset(self.hf_dataset.shard(self.num_shards, index=idx))
for idx in range(self.num_shards) for idx in range(self.num_shards)
} }
# This buffer is populated while iterating on the dataset's shards
frames_buffer = []
try: try:
while available_shards := list(idx_to_iterable_dataset.keys()): while available_shards := list(idx_to_backtracktable_dataset.keys()):
shard_key = next(self._infinite_generator_over_elements(available_shards)) shard_key = next(self._infinite_generator_over_elements(available_shards))
dataset = idx_to_iterable_dataset[shard_key] # selects which shard to iterate on dataset = idx_to_backtracktable_dataset[shard_key] # selects which shard to iterate on
for frame in self.make_frame(dataset): for frame in self.make_frame(dataset):
if len(frames_buffer) == self.buffer_size: if len(frames_buffer) == self.buffer_size:
i = next(buffer_indices_generator) i = next(buffer_indices_generator)
@@ -175,18 +212,18 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
RuntimeError, RuntimeError,
StopIteration, StopIteration,
): # NOTE: StopIteration inside a generator throws a RuntimeError since 3.7 ): # NOTE: StopIteration inside a generator throws a RuntimeError since 3.7
# Remove exhausted shard del idx_to_backtracktable_dataset[shard_key] # Remove exhausted shard, onto another shard
del idx_to_iterable_dataset[shard_key]
# Once shards are all exhausted, shuffle the buffer and yield the remaining frames # Once shards are all exhausted, shuffle the buffer and yield the remaining frames
self.rng.shuffle(frames_buffer) self.rng.shuffle(frames_buffer)
yield from frames_buffer yield from frames_buffer
def _make_iterable_dataset(self, dataset: datasets.IterableDataset) -> Iterator: def _make_backtrackable_dataset(self, dataset: datasets.IterableDataset) -> Backtrackable:
return iter(dataset) history, lookahead = self._get_window_steps()
return Backtrackable(dataset, history=history, lookahead=lookahead)
@profile @profile
def make_frame(self, dataset_iterator: datasets.IterableDataset) -> Generator: def make_frame(self, dataset_iterator: Backtrackable) -> Generator:
"""Makes a frame starting from a dataset iterator""" """Makes a frame starting from a dataset iterator"""
item = next(dataset_iterator) item = next(dataset_iterator)
item = item_to_torch(item) item = item_to_torch(item)
@@ -194,16 +231,20 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
# Get episode index from the item # Get episode index from the item
ep_idx = item["episode_index"] ep_idx = item["episode_index"]
# Apply delta querying logic if necessary
if self.delta_indices is not None:
query_result, padding = self._get_delta_frames(dataset_iterator, item)
item = {**item, **query_result, **padding}
# Load video frames, when needed # Load video frames, when needed
if len(self.meta.video_keys) > 0: if len(self.meta.video_keys) > 0:
current_ts = item["timestamp"] current_ts = item["timestamp"]
query_timestamps = self._get_query_timestamps(current_ts, None) query_indices = self.delta_indices if self.delta_indices is not None else None
query_timestamps = self._get_query_timestamps(current_ts, query_indices)
video_frames = self._query_videos(query_timestamps, ep_idx) video_frames = self._query_videos(query_timestamps, ep_idx)
item = {**video_frames, **item} item = {**video_frames, **item}
# Add task as a string item["task"] = self.meta.tasks.iloc[item["task_index"]].name
task_idx = item["task_index"]
item["task"] = self.meta.tasks.iloc[task_idx].name
yield item yield item
@@ -215,8 +256,9 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
query_timestamps = {} query_timestamps = {}
for key in self.meta.video_keys: for key in self.meta.video_keys:
if query_indices is not None and key in query_indices: if query_indices is not None and key in query_indices:
timestamps = self.hf_dataset.select(query_indices[key])["timestamp"] timestamps = current_ts + torch.tensor(query_indices[key]) / self.fps
query_timestamps[key] = torch.stack(timestamps).tolist() # never query for negative timestamps!
query_timestamps[key] = list(filter(lambda x: x >= 0, timestamps.tolist()))
else: else:
query_timestamps[key] = [current_ts] query_timestamps[key] = [current_ts]
@@ -239,14 +281,145 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
return item return item
def _get_future_frame(self, dataset_iterator: Backtrackable, n_steps: int):
if dataset_iterator.can_go_ahead(n_steps):
return dataset_iterator.peek_ahead(n_steps)
else:
pass
def _get_previous_frame(self, dataset_iterator: Backtrackable, n_steps: int):
if dataset_iterator.can_go_back(n_steps):
return dataset_iterator.peek_back(n_steps)
else:
pass
def _get_delta_frames(self, dataset_iterator: Backtrackable, current_item: dict):
    """Get frames with delta offsets using the backtrackable iterator.

    For every non-video key in ``self.delta_indices``, collects the frame at
    each relative offset: 0 reuses ``current_item``, negative offsets peek
    back into the iterator's history, positive offsets peek ahead into its
    lookahead buffer. Offsets that leave the window or cross an episode
    boundary are replaced with zero tensors and flagged as padding.

    Args:
        dataset_iterator (Backtrackable): Iterator wrapping the current shard;
            peeked (not advanced) to fetch past/future frames.
        current_item (dict): Current item from the iterator, already converted
            to torch tensors.

    Returns:
        tuple: (query_result, padding) - frames at delta offsets (stacked per
        key) and per-key boolean padding masks under ``"<key>.pad_masking"``.
    """
    current_episode_idx = current_item["episode_index"]

    # Prepare results
    query_result = {}
    padding = {}

    for key, delta_indices in self.delta_indices.items():
        if key in self.meta.video_keys:
            continue  # visual frames are decoded separately

        target_frames = []
        is_pad = []

        # NOTE(fracapuano): Optimize this. What's the point in checking all deltas after first error?
        for delta in delta_indices:
            if delta == 0:
                # Current frame
                target_frames.append(current_item[key])
                is_pad.append(False)

            elif delta < 0:
                # Past frame. Use backtrackable iterator, looking back delta steps.
                # Raise-then-catch funnels every failure mode (out of window,
                # episode boundary, or LookBackError raised by peek_back itself
                # — NOTE(review): confirm peek_back can raise it) into one
                # zero-padding path.
                try:
                    steps_back = abs(delta)
                    if dataset_iterator.can_peek_back(steps_back):
                        past_item = dataset_iterator.peek_back(steps_back)
                        past_item = item_to_torch(past_item)
                        # Check if it's from the same episode
                        if past_item["episode_index"] == current_episode_idx:
                            target_frames.append(past_item[key])
                            is_pad.append(False)
                        else:
                            raise LookBackError("Retrieved frame is from different episode!")
                    else:
                        raise LookBackError("Cannot go back further than the history buffer!")
                except LookBackError:
                    target_frames.append(torch.zeros_like(current_item[key]))
                    is_pad.append(True)

            elif delta > 0:
                # Future frame - read ahead from the iterator
                # (mirror of the past-frame branch, using the lookahead buffer).
                try:
                    if dataset_iterator.can_peek_ahead(delta):
                        future_item = dataset_iterator.peek_ahead(delta)
                        future_item = item_to_torch(future_item)
                        # Check if it's from the same episode
                        if future_item["episode_index"] == current_episode_idx:
                            target_frames.append(future_item[key])
                            is_pad.append(False)
                        else:
                            raise LookAheadError("Retrieved frame is from different episode!")
                    else:
                        raise LookAheadError("Cannot go ahead further than the lookahead buffer!")
                except LookAheadError:
                    target_frames.append(torch.zeros_like(current_item[key]))
                    is_pad.append(True)

        # Stack frames and add to results
        # (target_frames can only be empty if delta_indices[key] is empty).
        if target_frames:
            query_result[key] = torch.stack(target_frames)
            padding[f"{key}.pad_masking"] = torch.BoolTensor(is_pad)

    return query_result, padding
def _validate_delta_timestamp_keys(self, delta_timestamps: dict[list[float]]) -> None:
"""
Validate that all keys in delta_timestamps correspond to actual features in the dataset.
Raises:
ValueError: If any delta timestamp key doesn't correspond to a dataset feature.
"""
if delta_timestamps is None:
return
# Get all available feature keys from the dataset metadata
available_features = set(self.meta.features.keys())
# Get all keys from delta_timestamps
delta_keys = set(delta_timestamps.keys())
# Find any keys that don't correspond to features
invalid_keys = delta_keys - available_features
if invalid_keys:
raise ValueError(
f"The following delta_timestamp keys do not correspond to dataset features: {invalid_keys}. "
f"Available features are: {sorted(available_features)}"
)
# Example usage # Example usage
if __name__ == "__main__": if __name__ == "__main__":
from tqdm import tqdm
repo_id = "lerobot/aloha_mobile_cabinet" repo_id = "lerobot/aloha_mobile_cabinet"
dataset = StreamingLeRobotDataset(repo_id)
for i, frame in enumerate(dataset): camera_key = "observation.images.cam_right_wrist"
fps = 50
delta_timestamps = {
# loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
camera_key: [-1, -0.5, -0.20, 0],
# loads 6 state vectors: 1.5 seconds before, 1 second before, ... 200 ms, 100 ms, and current frame
"observation.state": [-1.5, -1, -0.5, -0.20, -0.10, 0],
# loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
"action": [t / fps for t in range(64)],
}
dataset = StreamingLeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
for i, frame in tqdm(enumerate(dataset)):
print(frame) print(frame)
if i > 1000: # only stream first 10 frames
if i > 10: # only stream first 10 frames
break break