forked from tangger/lerobot
add: support for delta timesteps on streaming mode
@@ -250,8 +250,24 @@ def profile_dataset(
     stats_file_path = "streaming_dataset_profile.txt"
 
     print(f"Creating dataset from {repo_id} with buffer_size={buffer_size}, max_num_shards={max_num_shards}")
+    camera_key = "observation.images.cam_right_wrist"
+    fps = 50
+
+    delta_timestamps = {
+        # loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
+        camera_key: [-1, -0.5, -0.20, 0],
+        # loads 6 state vectors: 1.5 seconds before, 1 second before, ... 200 ms, 100 ms, and current frame
+        "observation.state": [-1.5, -1, -0.5, -0.20, -0.10, 0],
+        # loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
+        "action": [t / fps for t in range(64)],
+    }
+
     dataset = StreamingLeRobotDataset(
-        repo_id=repo_id, buffer_size=buffer_size, max_num_shards=max_num_shards, seed=seed
+        repo_id=repo_id,
+        buffer_size=buffer_size,
+        max_num_shards=max_num_shards,
+        seed=seed,
+        delta_timestamps=delta_timestamps,
     )
 
     _time_iterations(dataset, num_samples, num_runs, warmup_iters, stats_file_path)
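For context: each `delta_timestamps` key maps a dataset feature to a list of offsets in seconds, which the loader converts into frame-index offsets at the dataset's fps. A minimal standalone sketch of that conversion (an assumption about what `get_delta_indices` computes, not the lerobot implementation itself):

```python
# Standalone sketch: second-based offsets -> frame-index offsets at a given fps.
fps = 50
delta_timestamps = {
    "observation.state": [-1.5, -1, -0.5, -0.20, -0.10, 0],
    "action": [t / fps for t in range(64)],
}

delta_indices = {
    key: [round(offset_s * fps) for offset_s in offsets]
    for key, offsets in delta_timestamps.items()
}

print(delta_indices["observation.state"])  # [-75, -50, -25, -10, -5, 0]
print(delta_indices["action"][:4])         # [0, 1, 2, 3]
```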
@@ -1,6 +1,6 @@
 import random
 from pathlib import Path
-from typing import Callable, Dict, Generator, Iterator
+from typing import Callable, Dict, Generator, Iterator, Tuple
 
 import datasets
 import numpy as np
@@ -11,7 +11,12 @@ from line_profiler import profile
 from lerobot.common.constants import HF_LEROBOT_HOME
 from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
 from lerobot.common.datasets.utils import (
+    Backtrackable,
+    LookAheadError,
+    LookBackError,
+    check_delta_timestamps,
     check_version_compatibility,
+    get_delta_indices,
     item_to_torch,
 )
 from lerobot.common.datasets.video_utils import (
@@ -65,6 +70,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
         root: str | Path | None = None,
         episodes: list[int] | None = None,
         image_transforms: Callable | None = None,
+        delta_timestamps: dict[str, list[float]] | None = None,
         tolerance_s: float = 1e-4,
         revision: str | None = None,
         force_cache_sync: bool = False,
@@ -123,9 +129,18 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
         # Check version
         check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
 
-        self.hf_dataset = self.load_hf_dataset()
+        if delta_timestamps is not None:
+            self._validate_delta_timestamp_keys(delta_timestamps)  # raises ValueError if invalid
+        self.delta_timestamps = delta_timestamps
+
+        self.hf_dataset: datasets.IterableDataset = self.load_hf_dataset()
         self.num_shards = min(self.hf_dataset.num_shards, max_num_shards)
 
+        max_backward_steps, max_forward_steps = self._get_window_steps()
+        self.backtrackable_dataset: Backtrackable = Backtrackable(
+            self.hf_dataset, history=max_backward_steps, lookahead=max_forward_steps
+        )
+
     @property
     def fps(self):
         return self.meta.fps
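`Backtrackable` is imported from `lerobot.common.datasets.utils` and its implementation is not part of this diff. Based on how it is used below (peeking backward and forward over a bounded window without consuming the underlying iterator), a minimal sketch of the contract might look like this; class name, internals, and edge-case behavior are assumptions:

```python
from collections import deque
from typing import Any, Iterable, Iterator


class BacktrackableSketch:
    """Minimal sketch of the Backtrackable contract assumed by this diff."""

    def __init__(self, iterable: Iterable, history: int = 1, lookahead: int = 1):
        self._it: Iterator = iter(iterable)
        self._past: deque = deque(maxlen=history)  # items already consumed
        self._future: deque = deque()              # items prefetched for peeking
        self._lookahead = lookahead
        self._started = False
        self._current: Any = None

    def __iter__(self) -> "BacktrackableSketch":
        return self

    def __next__(self) -> Any:
        if self._started:
            self._past.append(self._current)
        self._current = self._future.popleft() if self._future else next(self._it)
        self._started = True
        return self._current

    def can_peek_back(self, n: int) -> bool:
        return 0 < n <= len(self._past)

    def peek_back(self, n: int) -> Any:
        return self._past[-n]  # n steps behind the current item

    def can_peek_ahead(self, n: int) -> bool:
        # Prefetch lazily, but never beyond the configured lookahead window.
        while len(self._future) < n <= self._lookahead:
            try:
                self._future.append(next(self._it))
            except StopIteration:
                break
        return 0 < n <= len(self._future)

    def peek_ahead(self, n: int) -> Any:
        return self._future[n - 1]  # n steps past the current item


bt = BacktrackableSketch(range(10), history=2, lookahead=3)
next(bt), next(bt)            # current item is now 1
assert bt.peek_back(1) == 0   # one step back
assert bt.peek_ahead(2) == 3  # two steps ahead, without consuming them
assert next(bt) == 2          # prefetched items are replayed in order
```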
@@ -148,20 +163,42 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
         # TODO(fracapuano): Add support for streaming from a local folder and not only from HF Hub
         return dataset
 
+    def _get_window_steps(self) -> Tuple[int, int]:
+        """
+        Returns how many steps backward (and forward) the backtrackable iterator should maintain,
+        based on the input delta_timestamps.
+        """
+        max_backward_steps = 1
+        max_forward_steps = 1
+        self.delta_indices = None  # stays None when no delta_timestamps are given
+
+        if self.delta_timestamps is not None:
+            check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
+            self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
+
+            # Calculate the maximum backward steps needed (i.e., the history size)
+            for delta_idx in self.delta_indices.values():
+                min_delta = min(delta_idx)
+                max_delta = max(delta_idx)
+                if min_delta < 0:
+                    max_backward_steps = max(max_backward_steps, abs(min_delta))
+                if max_delta > 0:
+                    max_forward_steps = max(max_forward_steps, max_delta)
+
+        return max_backward_steps, max_forward_steps
+
     def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
         buffer_indices_generator = self._iter_random_indices(self.rng, self.buffer_size)
 
-        # This buffer is populated while iterating on the dataset's shards
-        frames_buffer = []
-        idx_to_iterable_dataset = {
-            idx: self._make_iterable_dataset(self.hf_dataset.shard(self.num_shards, index=idx))
+        idx_to_backtracktable_dataset = {
+            idx: self._make_backtrackable_dataset(self.hf_dataset.shard(self.num_shards, index=idx))
             for idx in range(self.num_shards)
         }
 
+        # This buffer is populated while iterating on the dataset's shards
+        frames_buffer = []
         try:
-            while available_shards := list(idx_to_iterable_dataset.keys()):
+            while available_shards := list(idx_to_backtracktable_dataset.keys()):
                 shard_key = next(self._infinite_generator_over_elements(available_shards))
-                dataset = idx_to_iterable_dataset[shard_key]  # selects which shard to iterate on
+                dataset = idx_to_backtracktable_dataset[shard_key]  # selects which shard to iterate on
                 for frame in self.make_frame(dataset):
                     if len(frames_buffer) == self.buffer_size:
                         i = next(buffer_indices_generator)
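As a concrete check of `_get_window_steps`, here is the window it derives for the profiling script's `delta_timestamps` at fps = 50, recomputed standalone with the same logic (the index lists follow the second-to-index mapping sketched earlier):

```python
# Frame-index offsets corresponding to the example delta_timestamps above
delta_indices = {
    "observation.images.cam_right_wrist": [-50, -25, -10, 0],
    "observation.state": [-75, -50, -25, -10, -5, 0],
    "action": list(range(64)),  # 0 .. 63
}

max_backward_steps, max_forward_steps = 1, 1
for delta_idx in delta_indices.values():
    if min(delta_idx) < 0:
        max_backward_steps = max(max_backward_steps, abs(min(delta_idx)))
    if max(delta_idx) > 0:
        max_forward_steps = max(max_forward_steps, max(delta_idx))

print(max_backward_steps, max_forward_steps)  # 75 63
```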
@@ -175,18 +212,18 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
                 RuntimeError,
                 StopIteration,
             ):  # NOTE: StopIteration inside a generator throws a RuntimeError since 3.7
-                # Remove exhausted shard
-                del idx_to_iterable_dataset[shard_key]
+                del idx_to_backtracktable_dataset[shard_key]  # Remove exhausted shard, move on to another shard
 
         # Once shards are all exhausted, shuffle the buffer and yield the remaining frames
         self.rng.shuffle(frames_buffer)
         yield from frames_buffer
 
-    def _make_iterable_dataset(self, dataset: datasets.IterableDataset) -> Iterator:
-        return iter(dataset)
+    def _make_backtrackable_dataset(self, dataset: datasets.IterableDataset) -> Backtrackable:
+        history, lookahead = self._get_window_steps()
+        return Backtrackable(dataset, history=history, lookahead=lookahead)
 
     @profile
-    def make_frame(self, dataset_iterator: datasets.IterableDataset) -> Generator:
+    def make_frame(self, dataset_iterator: Backtrackable) -> Generator:
         """Makes a frame starting from a dataset iterator"""
         item = next(dataset_iterator)
         item = item_to_torch(item)
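The `__iter__` body above is a classic streaming shuffle buffer: frames from the randomly interleaved shards fill a fixed-size buffer, and once it is full, each incoming frame evicts (and yields) a uniformly random occupant. A standalone sketch of the pattern, with the shard handling simplified away:

```python
import random


def shuffled_stream(frames, buffer_size: int, seed: int = 0):
    """Approximate, standalone sketch of the shuffle-buffer used in __iter__."""
    rng = random.Random(seed)
    buffer = []
    for frame in frames:
        if len(buffer) == buffer_size:
            i = rng.randrange(buffer_size)  # plays the role of buffer_indices_generator
            yield buffer[i]                 # evict a random frame...
            buffer[i] = frame               # ...and store the new one in its place
        else:
            buffer.append(frame)
    rng.shuffle(buffer)                     # all shards exhausted: flush the remainder
    yield from buffer


print(list(shuffled_stream(range(8), buffer_size=4)))
```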
@@ -194,16 +231,20 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
         # Get episode index from the item
         ep_idx = item["episode_index"]
 
+        # Apply delta querying logic if necessary
+        if self.delta_indices is not None:
+            query_result, padding = self._get_delta_frames(dataset_iterator, item)
+            item = {**item, **query_result, **padding}
+
         # Load video frames, when needed
         if len(self.meta.video_keys) > 0:
             current_ts = item["timestamp"]
-            query_timestamps = self._get_query_timestamps(current_ts, None)
+            query_indices = self.delta_indices if self.delta_indices is not None else None
+            query_timestamps = self._get_query_timestamps(current_ts, query_indices)
             video_frames = self._query_videos(query_timestamps, ep_idx)
             item = {**video_frames, **item}
 
         # Add task as a string
-        task_idx = item["task_index"]
-        item["task"] = self.meta.tasks.iloc[task_idx].name
+        item["task"] = self.meta.tasks.iloc[item["task_index"]].name
 
         yield item
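With delta querying active, each yielded item carries the stacked frames for every non-video key plus a companion `<key>.pad_masking` boolean tensor. A usage sketch, assuming `dataset` and `delta_timestamps` as constructed in the profiling script above (`state_dim` is dataset-dependent):

```python
import torch

item = next(iter(dataset))

state = item["observation.state"]                  # shape: (6, state_dim), one row per offset
state_pad = item["observation.state.pad_masking"]  # shape: (6,), True where zero-padded

assert state.shape[0] == len(delta_timestamps["observation.state"])
assert state_pad.dtype == torch.bool

valid_state = state[~state_pad]  # e.g. drop padded timesteps before computing a loss
```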
@@ -215,8 +256,9 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
         query_timestamps = {}
         for key in self.meta.video_keys:
             if query_indices is not None and key in query_indices:
-                timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
-                query_timestamps[key] = torch.stack(timestamps).tolist()
+                timestamps = current_ts + torch.tensor(query_indices[key]) / self.fps
+                # never query for negative timestamps!
+                query_timestamps[key] = list(filter(lambda x: x >= 0, timestamps.tolist()))
             else:
                 query_timestamps[key] = [current_ts]
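In streaming mode there is no random-access `select`, so the new branch derives video timestamps arithmetically from the current timestamp and the frame-index offsets, then drops anything that falls before the start of the stream. A worked example with illustrative values:

```python
import torch

fps = 50
current_ts = 0.30
query_indices = [-50, -25, -10, 0]  # index offsets for a camera key

timestamps = current_ts + torch.tensor(query_indices) / fps
# tensor([-0.7000, -0.2000,  0.1000,  0.3000])
valid = [t for t in timestamps.tolist() if t >= 0]
print(valid)  # ~[0.1, 0.3] -- negative timestamps are dropped, not clamped
```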
@@ -239,14 +281,145 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
 
         return item
 
+    def _get_future_frame(self, dataset_iterator: Backtrackable, n_steps: int):
+        if dataset_iterator.can_go_ahead(n_steps):
+            return dataset_iterator.peek_ahead(n_steps)
+        else:
+            return None
+
+    def _get_previous_frame(self, dataset_iterator: Backtrackable, n_steps: int):
+        if dataset_iterator.can_go_back(n_steps):
+            return dataset_iterator.peek_back(n_steps)
+        else:
+            return None
+
+    def _get_delta_frames(self, dataset_iterator: Backtrackable, current_item: dict):
+        """Get frames at delta offsets using the backtrackable iterator.
+
+        Args:
+            dataset_iterator (Backtrackable): Iterator positioned on the current frame.
+            current_item (dict): Current item from the iterator.
+
+        Returns:
+            tuple: (query_result, padding) - frames at the delta offsets and padding info.
+        """
+        current_episode_idx = current_item["episode_index"]
+
+        # Prepare results
+        query_result = {}
+        padding = {}
+
+        for key, delta_indices in self.delta_indices.items():
+            if key in self.meta.video_keys:
+                continue  # visual frames are decoded separately
+
+            target_frames = []
+            is_pad = []
+
+            # NOTE(fracapuano): Optimize this. What's the point in checking all deltas after the first error?
+            for delta in delta_indices:
+                if delta == 0:
+                    # Current frame
+                    target_frames.append(current_item[key])
+                    is_pad.append(False)
+
+                elif delta < 0:
+                    # Past frame. Use the backtrackable iterator, looking back |delta| steps
+                    try:
+                        steps_back = abs(delta)
+                        if dataset_iterator.can_peek_back(steps_back):
+                            past_item = dataset_iterator.peek_back(steps_back)
+                            past_item = item_to_torch(past_item)
+
+                            # Check if it's from the same episode
+                            if past_item["episode_index"] == current_episode_idx:
+                                target_frames.append(past_item[key])
+                                is_pad.append(False)
+                            else:
+                                raise LookBackError("Retrieved frame is from a different episode!")
+                        else:
+                            raise LookBackError("Cannot go back further than the history buffer!")
+
+                    except LookBackError:
+                        target_frames.append(torch.zeros_like(current_item[key]))
+                        is_pad.append(True)
+
+                elif delta > 0:
+                    # Future frame - read ahead from the iterator
+                    try:
+                        if dataset_iterator.can_peek_ahead(delta):
+                            future_item = dataset_iterator.peek_ahead(delta)
+                            future_item = item_to_torch(future_item)
+
+                            # Check if it's from the same episode
+                            if future_item["episode_index"] == current_episode_idx:
+                                target_frames.append(future_item[key])
+                                is_pad.append(False)
+                            else:
+                                raise LookAheadError("Retrieved frame is from a different episode!")
+                        else:
+                            raise LookAheadError("Cannot go ahead further than the lookahead buffer!")
+
+                    except LookAheadError:
+                        target_frames.append(torch.zeros_like(current_item[key]))
+                        is_pad.append(True)
+
+            # Stack frames and add to results
+            if target_frames:
+                query_result[key] = torch.stack(target_frames)
+                padding[f"{key}.pad_masking"] = torch.BoolTensor(is_pad)
+
+        return query_result, padding
+
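To make the padding convention concrete: an offset that crosses an episode boundary, or falls outside the history/lookahead window, contributes a zero tensor and a `True` entry in the mask. A minimal illustration with hypothetical values:

```python
import torch

current = torch.tensor([1.0, 2.0])  # e.g. a 2-dim state vector
# offsets [-1, 0]: the -1 step crossed an episode boundary and was padded
target_frames = [torch.zeros_like(current), current]
is_pad = [True, False]

stacked = torch.stack(target_frames)    # shape: (2, 2)
pad_masking = torch.BoolTensor(is_pad)  # tensor([ True, False])
```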
+    def _validate_delta_timestamp_keys(self, delta_timestamps: dict[str, list[float]]) -> None:
+        """
+        Validate that all keys in delta_timestamps correspond to actual features in the dataset.
+
+        Raises:
+            ValueError: If any delta timestamp key doesn't correspond to a dataset feature.
+        """
+        if delta_timestamps is None:
+            return
+
+        # Get all available feature keys from the dataset metadata
+        available_features = set(self.meta.features.keys())
+
+        # Get all keys from delta_timestamps
+        delta_keys = set(delta_timestamps.keys())
+
+        # Find any keys that don't correspond to features
+        invalid_keys = delta_keys - available_features
+
+        if invalid_keys:
+            raise ValueError(
+                f"The following delta_timestamp keys do not correspond to dataset features: {invalid_keys}. "
+                f"Available features are: {sorted(available_features)}"
+            )
+
+
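A hypothetical misuse that this validation would catch, assuming `observation.stats` is not a feature of the dataset:

```python
try:
    StreamingLeRobotDataset(
        "lerobot/aloha_mobile_cabinet",
        delta_timestamps={"observation.stats": [-0.1, 0]},  # typo: no such feature
    )
except ValueError as e:
    print(e)  # names the invalid keys and lists the available features
```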
 # Example usage
 if __name__ == "__main__":
     from tqdm import tqdm
 
     repo_id = "lerobot/aloha_mobile_cabinet"
-    dataset = StreamingLeRobotDataset(repo_id)
+    camera_key = "observation.images.cam_right_wrist"
+    fps = 50
+
+    delta_timestamps = {
+        # loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
+        camera_key: [-1, -0.5, -0.20, 0],
+        # loads 6 state vectors: 1.5 seconds before, 1 second before, ... 200 ms, 100 ms, and current frame
+        "observation.state": [-1.5, -1, -0.5, -0.20, -0.10, 0],
+        # loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
+        "action": [t / fps for t in range(64)],
+    }
+
+    dataset = StreamingLeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
 
-    for i, frame in enumerate(dataset):
+    for i, frame in tqdm(enumerate(dataset)):
         print(frame)
 
-        if i > 10:  # only stream first 10 frames
+        if i > 1000:  # only stream the first 1000 frames
             break