add: support for delta timestamps in streaming mode

fracapuano
2025-05-29 23:11:28 +02:00
parent e405f37b95
commit 90969e57ff
2 changed files with 214 additions and 25 deletions

View File

@@ -250,8 +250,24 @@ def profile_dataset(
stats_file_path = "streaming_dataset_profile.txt"
print(f"Creating dataset from {repo_id} with buffer_size={buffer_size}, max_num_shards={max_num_shards}")
camera_key = "observation.images.cam_right_wrist"
fps = 50
delta_timestamps = {
# loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
camera_key: [-1, -0.5, -0.20, 0],
# loads 6 state vectors: 1.5 seconds before, 1 second before, ... 200 ms, 100 ms, and current frame
"observation.state": [-1.5, -1, -0.5, -0.20, -0.10, 0],
# loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
"action": [t / fps for t in range(64)],
}
dataset = StreamingLeRobotDataset(
repo_id=repo_id, buffer_size=buffer_size, max_num_shards=max_num_shards, seed=seed
repo_id=repo_id,
buffer_size=buffer_size,
max_num_shards=max_num_shards,
seed=seed,
delta_timestamps=delta_timestamps,
)
_time_iterations(dataset, num_samples, num_runs, warmup_iters, stats_file_path)
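The delta offsets above are expressed in seconds and are converted to integer frame offsets using the dataset fps. A quick sketch of that conversion (mirroring what get_delta_indices is expected to do; the exact rounding is an assumption):

fps = 50
delta_timestamps = {"observation.state": [-1.5, -1, -0.5, -0.20, -0.10, 0]}
# seconds -> frame offsets at 50 fps
delta_indices = {k: [round(t * fps) for t in v] for k, v in delta_timestamps.items()}
print(delta_indices)  # {'observation.state': [-75, -50, -25, -10, -5, 0]}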

View File

@@ -1,6 +1,6 @@
import random
from pathlib import Path
from typing import Callable, Dict, Generator, Iterator
from typing import Callable, Dict, Generator, Iterator, Tuple
import datasets
import numpy as np
@@ -11,7 +11,12 @@ from line_profiler import profile
from lerobot.common.constants import HF_LEROBOT_HOME
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
from lerobot.common.datasets.utils import (
Backtrackable,
LookAheadError,
LookBackError,
check_delta_timestamps,
check_version_compatibility,
get_delta_indices,
item_to_torch,
)
from lerobot.common.datasets.video_utils import (
@@ -65,6 +70,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
root: str | Path | None = None,
episodes: list[int] | None = None,
image_transforms: Callable | None = None,
delta_timestamps: dict[str, list[float]] | None = None,
tolerance_s: float = 1e-4,
revision: str | None = None,
force_cache_sync: bool = False,
@@ -123,9 +129,18 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
# Check version
check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
self.hf_dataset = self.load_hf_dataset()
if delta_timestamps is not None:
self._validate_delta_timestamp_keys(delta_timestamps) # raises ValueError if invalid
self.delta_timestamps = delta_timestamps
self.hf_dataset: datasets.IterableDataset = self.load_hf_dataset()
self.num_shards = min(self.hf_dataset.num_shards, max_num_shards)
max_backward_steps, max_forward_steps = self._get_window_steps()
self.backtrackable_dataset: Backtrackable = Backtrackable(
self.hf_dataset, history=max_backward_steps, lookahead=max_forward_steps
)
@property
def fps(self):
return self.meta.fps
@@ -148,20 +163,42 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
# TODO(fracapuano): Add support for streaming from a local folder and not only from HF Hub
return dataset
def _get_window_steps(self) -> Tuple[int, int]:
"""
Returns how many steps backward (and forward) the backtrackable iterator should maintain,
based on the input delta_timestamps.
"""
max_backward_steps = 1
max_forward_steps = 1
self.delta_indices = None  # default, so make_frame can check it even when no delta_timestamps are given
if self.delta_timestamps is not None:
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
# Calculate the maximum backward and forward steps needed (i.e., history and lookahead sizes)
for delta_idx in self.delta_indices.values():
min_delta = min(delta_idx)
max_delta = max(delta_idx)
if min_delta < 0:
max_backward_steps = max(max_backward_steps, abs(min_delta))
if max_delta > 0:
max_forward_steps = max(max_forward_steps, max_delta)
return max_backward_steps, max_forward_steps
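For the example delta_timestamps used in the profiling script (state back to -1.5 s, actions up to 63 frames ahead, at 50 fps), this is the window the method derives; a worked sketch of the same arithmetic:

delta_indices = {
    "observation.state": [-75, -50, -25, -10, -5, 0],  # -1.5 s ... 0 s at 50 fps
    "action": list(range(64)),                         # 0 ... +63 frames
}
max_backward_steps = max(1, *(abs(min(v)) for v in delta_indices.values()))  # 75
max_forward_steps = max(1, *(max(v) for v in delta_indices.values()))        # 63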
def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
buffer_indices_generator = self._iter_random_indices(self.rng, self.buffer_size)
# This buffer is populated while iterating on the dataset's shards
frames_buffer = []
idx_to_iterable_dataset = {
idx: self._make_iterable_dataset(self.hf_dataset.shard(self.num_shards, index=idx))
idx_to_backtrackable_dataset = {
idx: self._make_backtrackable_dataset(self.hf_dataset.shard(self.num_shards, index=idx))
for idx in range(self.num_shards)
}
# This buffer is populated while iterating on the dataset's shards
frames_buffer = []
try:
while available_shards := list(idx_to_iterable_dataset.keys()):
while available_shards := list(idx_to_backtrackable_dataset.keys()):
shard_key = next(self._infinite_generator_over_elements(available_shards))
dataset = idx_to_iterable_dataset[shard_key] # selects which shard to iterate on
dataset = idx_to_backtrackable_dataset[shard_key] # selects which shard to iterate on
for frame in self.make_frame(dataset):
if len(frames_buffer) == self.buffer_size:
i = next(buffer_indices_generator)
@@ -175,18 +212,18 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
RuntimeError,
StopIteration,
): # NOTE: since Python 3.7 (PEP 479), a StopIteration raised inside a generator surfaces as a RuntimeError
# Remove exhausted shard
del idx_to_iterable_dataset[shard_key]
del idx_to_backtrackable_dataset[shard_key] # Remove the exhausted shard and move on to the next
# Once shards are all exhausted, shuffle the buffer and yield the remaining frames
self.rng.shuffle(frames_buffer)
yield from frames_buffer
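The buffer implements the standard streaming-shuffle pattern: keep buffer_size frames, and for each incoming frame swap out a random slot and yield whatever was there; once the stream ends, shuffle and drain what remains. A minimal sketch of the pattern:

import random

def shuffle_stream(stream, buffer_size: int, rng: random.Random):
    buffer = []
    for frame in stream:
        if len(buffer) == buffer_size:
            i = rng.randrange(buffer_size)  # random slot to evict
            yield buffer[i]
            buffer[i] = frame
        else:
            buffer.append(frame)
    rng.shuffle(buffer)  # drain the remainder in random order
    yield from buffer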
def _make_iterable_dataset(self, dataset: datasets.IterableDataset) -> Iterator:
return iter(dataset)
def _make_backtrackable_dataset(self, dataset: datasets.IterableDataset) -> Backtrackable:
history, lookahead = self._get_window_steps()
return Backtrackable(dataset, history=history, lookahead=lookahead)
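Backtrackable itself is imported from lerobot.common.datasets.utils and is not part of this diff. A minimal sketch of a compatible wrapper, assuming the can_peek_back/peek_back/can_peek_ahead/peek_ahead API used below (the real implementation may differ):

from collections import deque
from typing import Iterator

class BacktrackableSketch:
    """Iterator wrapper keeping a bounded history and lookahead window."""

    def __init__(self, iterable, history: int = 1, lookahead: int = 1):
        self._it: Iterator = iter(iterable)
        self._history = deque(maxlen=history + 1)  # history slots plus the current item
        self._lookahead: deque = deque()           # fetched but not yet consumed
        self._max_lookahead = lookahead

    def __iter__(self):
        return self

    def __next__(self):
        item = self._lookahead.popleft() if self._lookahead else next(self._it)
        self._history.appendleft(item)  # _history[0] is always the current item
        return item

    def can_peek_back(self, n: int) -> bool:
        return n < len(self._history)

    def peek_back(self, n: int):
        return self._history[n]

    def can_peek_ahead(self, n: int) -> bool:
        if n > self._max_lookahead:
            return False
        try:
            while len(self._lookahead) < n:  # fetch lazily, up to the window size
                self._lookahead.append(next(self._it))
        except StopIteration:
            return False
        return True

    def peek_ahead(self, n: int):
        return self._lookahead[n - 1]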
@profile
def make_frame(self, dataset_iterator: datasets.IterableDataset) -> Generator:
def make_frame(self, dataset_iterator: Backtrackable) -> Generator:
"""Makes a frame starting from a dataset iterator"""
item = next(dataset_iterator)
item = item_to_torch(item)
@@ -194,16 +231,20 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
# Get episode index from the item
ep_idx = item["episode_index"]
# Apply delta querying logic if necessary
if self.delta_indices is not None:
query_result, padding = self._get_delta_frames(dataset_iterator, item)
item = {**item, **query_result, **padding}
# Load video frames, when needed
if len(self.meta.video_keys) > 0:
current_ts = item["timestamp"]
query_timestamps = self._get_query_timestamps(current_ts, None)
query_timestamps = self._get_query_timestamps(current_ts, self.delta_indices)
video_frames = self._query_videos(query_timestamps, ep_idx)
item = {**video_frames, **item}
# Add task as a string
task_idx = item["task_index"]
item["task"] = self.meta.tasks.iloc[task_idx].name
item["task"] = self.meta.tasks.iloc[item["task_index"]].name
yield item
@@ -215,8 +256,9 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
query_timestamps = {}
for key in self.meta.video_keys:
if query_indices is not None and key in query_indices:
timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
query_timestamps[key] = torch.stack(timestamps).tolist()
timestamps = current_ts + torch.tensor(query_indices[key]) / self.fps
# never query for negative timestamps!
query_timestamps[key] = [ts for ts in timestamps.tolist() if ts >= 0]
else:
query_timestamps[key] = [current_ts]
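For instance, with the current frame at t = 0.1 s and delta indices [-10, -5, 0] at 50 fps, the computed timestamps are roughly [-0.1, 0.0, 0.1] and the negative entry is dropped; a worked sketch:

import torch

current_ts, fps = 0.1, 50
delta_idx = [-10, -5, 0]  # 200 ms back, 100 ms back, current frame
timestamps = current_ts + torch.tensor(delta_idx) / fps
query_ts = [ts for ts in timestamps.tolist() if ts >= 0]  # keeps ~[0.0, 0.1]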
@@ -239,14 +281,145 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
return item
def _get_future_frame(self, dataset_iterator: Backtrackable, n_steps: int):
    """Peek n_steps ahead, returning None once the lookahead window is exhausted."""
    if dataset_iterator.can_peek_ahead(n_steps):
        return dataset_iterator.peek_ahead(n_steps)
    return None
def _get_previous_frame(self, dataset_iterator: Backtrackable, n_steps: int):
    """Peek n_steps back, returning None once the history window is exhausted."""
    if dataset_iterator.can_peek_back(n_steps):
        return dataset_iterator.peek_back(n_steps)
    return None
def _get_delta_frames(self, dataset_iterator: Backtrackable, current_item: dict):
"""Get frames with delta offsets using the backtrackable iterator.
Args:
dataset_iterator (Backtrackable): Iterator positioned at the current frame.
current_item (dict): Current item from the iterator.
Returns:
tuple: (query_result, padding) - frames at delta offsets and padding info.
"""
current_episode_idx = current_item["episode_index"]
# Prepare results
query_result = {}
padding = {}
for key, delta_indices in self.delta_indices.items():
if key in self.meta.video_keys:
continue # visual frames are decoded separately
target_frames = []
is_pad = []
# NOTE(fracapuano): Optimize this. What's the point in checking all deltas after first error?
for delta in delta_indices:
if delta == 0:
# Current frame
target_frames.append(current_item[key])
is_pad.append(False)
elif delta < 0:
# Past frame. Use backtrackable iterator, looking back delta steps
try:
steps_back = abs(delta)
if dataset_iterator.can_peek_back(steps_back):
past_item = dataset_iterator.peek_back(steps_back)
past_item = item_to_torch(past_item)
# Check if it's from the same episode
if past_item["episode_index"] == current_episode_idx:
target_frames.append(past_item[key])
is_pad.append(False)
else:
raise LookBackError("Retrieved frame is from different episode!")
else:
raise LookBackError("Cannot go back further than the history buffer!")
except LookBackError:
target_frames.append(torch.zeros_like(current_item[key]))
is_pad.append(True)
elif delta > 0:
# Future frame - read ahead from the iterator
try:
if dataset_iterator.can_peek_ahead(delta):
future_item = dataset_iterator.peek_ahead(delta)
future_item = item_to_torch(future_item)
# Check if it's from the same episode
if future_item["episode_index"] == current_episode_idx:
target_frames.append(future_item[key])
is_pad.append(False)
else:
raise LookAheadError("Retrieved frame is from different episode!")
else:
raise LookAheadError("Cannot go ahead further than the lookahead buffer!")
except LookAheadError:
target_frames.append(torch.zeros_like(current_item[key]))
is_pad.append(True)
# Stack frames and add to results
if target_frames:
query_result[key] = torch.stack(target_frames)
padding[f"{key}.pad_masking"] = torch.BoolTensor(is_pad)
return query_result, padding
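Downstream, each delta-queried key then carries a stacked tensor with one row per offset, plus a boolean mask marking zero-padded rows. Expected shapes for the example delta_timestamps (the 14-dim state/action is an assumption, in the style of ALOHA datasets):

item["observation.state"].shape        # torch.Size([6, 14]): one row per delta offset
item["observation.state.pad_masking"]  # shape (6,); True where the frame was zero-padded
item["action"].shape                   # torch.Size([64, 14])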
def _validate_delta_timestamp_keys(self, delta_timestamps: dict[str, list[float]]) -> None:
"""
Validate that all keys in delta_timestamps correspond to actual features in the dataset.
Raises:
ValueError: If any delta timestamp key doesn't correspond to a dataset feature.
"""
if delta_timestamps is None:
return
# Get all available feature keys from the dataset metadata
available_features = set(self.meta.features.keys())
# Get all keys from delta_timestamps
delta_keys = set(delta_timestamps.keys())
# Find any keys that don't correspond to features
invalid_keys = delta_keys - available_features
if invalid_keys:
raise ValueError(
f"The following delta_timestamp keys do not correspond to dataset features: {invalid_keys}. "
f"Available features are: {sorted(available_features)}"
)
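A misspelled feature key therefore fails fast at construction time rather than mid-iteration; for example (hypothetical typo'd key):

StreamingLeRobotDataset(
    repo_id="lerobot/aloha_mobile_cabinet",
    delta_timestamps={"observation.images.cam_rigth_wrist": [-1, 0]},  # deliberate typo
)
# ValueError: The following delta_timestamp keys do not correspond to dataset features: ...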
# Example usage
if __name__ == "__main__":
from tqdm import tqdm
repo_id = "lerobot/aloha_mobile_cabinet"
dataset = StreamingLeRobotDataset(repo_id)
for i, frame in enumerate(dataset):
camera_key = "observation.images.cam_right_wrist"
fps = 50
delta_timestamps = {
# loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
camera_key: [-1, -0.5, -0.20, 0],
# loads 6 state vectors: 1.5 seconds before, 1 second before, ... 200 ms, 100 ms, and current frame
"observation.state": [-1.5, -1, -0.5, -0.20, -0.10, 0],
# loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
"action": [t / fps for t in range(64)],
}
dataset = StreamingLeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
for i, frame in tqdm(enumerate(dataset)):
print(frame)
if i > 10: # only stream first 10 frames
if i > 1000: # only stream the first 1000 frames
break