forked from tangger/lerobot
add: support for delta timesteps on streaming mode
This commit is contained in:
@@ -250,8 +250,24 @@ def profile_dataset(
|
|||||||
stats_file_path = "streaming_dataset_profile.txt"
|
stats_file_path = "streaming_dataset_profile.txt"
|
||||||
|
|
||||||
print(f"Creating dataset from {repo_id} with buffer_size={buffer_size}, max_num_shards={max_num_shards}")
|
print(f"Creating dataset from {repo_id} with buffer_size={buffer_size}, max_num_shards={max_num_shards}")
|
||||||
|
camera_key = "observation.images.cam_right_wrist"
|
||||||
|
fps = 50
|
||||||
|
|
||||||
|
delta_timestamps = {
|
||||||
|
# loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
|
||||||
|
camera_key: [-1, -0.5, -0.20, 0],
|
||||||
|
# loads 6 state vectors: 1.5 seconds before, 1 second before, ... 200 ms, 100 ms, and current frame
|
||||||
|
"observation.state": [-1.5, -1, -0.5, -0.20, -0.10, 0],
|
||||||
|
# loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
|
||||||
|
"action": [t / fps for t in range(64)],
|
||||||
|
}
|
||||||
|
|
||||||
dataset = StreamingLeRobotDataset(
|
dataset = StreamingLeRobotDataset(
|
||||||
repo_id=repo_id, buffer_size=buffer_size, max_num_shards=max_num_shards, seed=seed
|
repo_id=repo_id,
|
||||||
|
buffer_size=buffer_size,
|
||||||
|
max_num_shards=max_num_shards,
|
||||||
|
seed=seed,
|
||||||
|
delta_timestamps=delta_timestamps,
|
||||||
)
|
)
|
||||||
|
|
||||||
_time_iterations(dataset, num_samples, num_runs, warmup_iters, stats_file_path)
|
_time_iterations(dataset, num_samples, num_runs, warmup_iters, stats_file_path)
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import random
|
import random
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Dict, Generator, Iterator
|
from typing import Callable, Dict, Generator, Iterator, Tuple
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -11,7 +11,12 @@ from line_profiler import profile
|
|||||||
from lerobot.common.constants import HF_LEROBOT_HOME
|
from lerobot.common.constants import HF_LEROBOT_HOME
|
||||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
|
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
|
||||||
from lerobot.common.datasets.utils import (
|
from lerobot.common.datasets.utils import (
|
||||||
|
Backtrackable,
|
||||||
|
LookAheadError,
|
||||||
|
LookBackError,
|
||||||
|
check_delta_timestamps,
|
||||||
check_version_compatibility,
|
check_version_compatibility,
|
||||||
|
get_delta_indices,
|
||||||
item_to_torch,
|
item_to_torch,
|
||||||
)
|
)
|
||||||
from lerobot.common.datasets.video_utils import (
|
from lerobot.common.datasets.video_utils import (
|
||||||
@@ -65,6 +70,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
|||||||
root: str | Path | None = None,
|
root: str | Path | None = None,
|
||||||
episodes: list[int] | None = None,
|
episodes: list[int] | None = None,
|
||||||
image_transforms: Callable | None = None,
|
image_transforms: Callable | None = None,
|
||||||
|
delta_timestamps: dict[list[float]] | None = None,
|
||||||
tolerance_s: float = 1e-4,
|
tolerance_s: float = 1e-4,
|
||||||
revision: str | None = None,
|
revision: str | None = None,
|
||||||
force_cache_sync: bool = False,
|
force_cache_sync: bool = False,
|
||||||
@@ -123,9 +129,18 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
|||||||
# Check version
|
# Check version
|
||||||
check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
|
check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
|
||||||
|
|
||||||
self.hf_dataset = self.load_hf_dataset()
|
if delta_timestamps is not None:
|
||||||
|
self._validate_delta_timestamp_keys(delta_timestamps) # raises ValueError if invalid
|
||||||
|
self.delta_timestamps = delta_timestamps
|
||||||
|
|
||||||
|
self.hf_dataset: datasets.IterableDataset = self.load_hf_dataset()
|
||||||
self.num_shards = min(self.hf_dataset.num_shards, max_num_shards)
|
self.num_shards = min(self.hf_dataset.num_shards, max_num_shards)
|
||||||
|
|
||||||
|
max_backward_steps, max_forward_steps = self._get_window_steps()
|
||||||
|
self.backtrackable_dataset: Backtrackable = Backtrackable(
|
||||||
|
self.hf_dataset, history=max_backward_steps, lookahead=max_forward_steps
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def fps(self):
|
def fps(self):
|
||||||
return self.meta.fps
|
return self.meta.fps
|
||||||
@@ -148,20 +163,42 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
|||||||
# TODO(fracapuano): Add support for streaming from a local folder and not only from HF Hub
|
# TODO(fracapuano): Add support for streaming from a local folder and not only from HF Hub
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
def _get_window_steps(self) -> Tuple[int, int]:
|
||||||
|
"""
|
||||||
|
Returns how many steps backward (& forward) should the backtrackable iterator maintain,
|
||||||
|
based on the input delta_timestamps.
|
||||||
|
"""
|
||||||
|
max_backward_steps = 1
|
||||||
|
max_forward_steps = 1
|
||||||
|
|
||||||
|
if self.delta_timestamps is not None:
|
||||||
|
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
|
||||||
|
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
|
||||||
|
|
||||||
|
# Calculate maximum backward steps needed (i.e., history size)
|
||||||
|
for delta_idx in self.delta_indices.values():
|
||||||
|
min_delta = min(delta_idx)
|
||||||
|
max_delta = max(delta_idx)
|
||||||
|
if min_delta < 0:
|
||||||
|
max_backward_steps = max(max_backward_steps, abs(min_delta))
|
||||||
|
if max_delta > 0:
|
||||||
|
max_forward_steps = max(max_forward_steps, max_delta)
|
||||||
|
|
||||||
|
return max_backward_steps, max_forward_steps
|
||||||
|
|
||||||
def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
|
def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
|
||||||
buffer_indices_generator = self._iter_random_indices(self.rng, self.buffer_size)
|
buffer_indices_generator = self._iter_random_indices(self.rng, self.buffer_size)
|
||||||
|
idx_to_backtracktable_dataset = {
|
||||||
# This buffer is populated while iterating on the dataset's shards
|
idx: self._make_backtrackable_dataset(self.hf_dataset.shard(self.num_shards, index=idx))
|
||||||
frames_buffer = []
|
|
||||||
idx_to_iterable_dataset = {
|
|
||||||
idx: self._make_iterable_dataset(self.hf_dataset.shard(self.num_shards, index=idx))
|
|
||||||
for idx in range(self.num_shards)
|
for idx in range(self.num_shards)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# This buffer is populated while iterating on the dataset's shards
|
||||||
|
frames_buffer = []
|
||||||
try:
|
try:
|
||||||
while available_shards := list(idx_to_iterable_dataset.keys()):
|
while available_shards := list(idx_to_backtracktable_dataset.keys()):
|
||||||
shard_key = next(self._infinite_generator_over_elements(available_shards))
|
shard_key = next(self._infinite_generator_over_elements(available_shards))
|
||||||
dataset = idx_to_iterable_dataset[shard_key] # selects which shard to iterate on
|
dataset = idx_to_backtracktable_dataset[shard_key] # selects which shard to iterate on
|
||||||
for frame in self.make_frame(dataset):
|
for frame in self.make_frame(dataset):
|
||||||
if len(frames_buffer) == self.buffer_size:
|
if len(frames_buffer) == self.buffer_size:
|
||||||
i = next(buffer_indices_generator)
|
i = next(buffer_indices_generator)
|
||||||
@@ -175,18 +212,18 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
|||||||
RuntimeError,
|
RuntimeError,
|
||||||
StopIteration,
|
StopIteration,
|
||||||
): # NOTE: StopIteration inside a generator throws a RuntimeError since 3.7
|
): # NOTE: StopIteration inside a generator throws a RuntimeError since 3.7
|
||||||
# Remove exhausted shard
|
del idx_to_backtracktable_dataset[shard_key] # Remove exhausted shard, onto another shard
|
||||||
del idx_to_iterable_dataset[shard_key]
|
|
||||||
|
|
||||||
# Once shards are all exhausted, shuffle the buffer and yield the remaining frames
|
# Once shards are all exhausted, shuffle the buffer and yield the remaining frames
|
||||||
self.rng.shuffle(frames_buffer)
|
self.rng.shuffle(frames_buffer)
|
||||||
yield from frames_buffer
|
yield from frames_buffer
|
||||||
|
|
||||||
def _make_iterable_dataset(self, dataset: datasets.IterableDataset) -> Iterator:
|
def _make_backtrackable_dataset(self, dataset: datasets.IterableDataset) -> Backtrackable:
|
||||||
return iter(dataset)
|
history, lookahead = self._get_window_steps()
|
||||||
|
return Backtrackable(dataset, history=history, lookahead=lookahead)
|
||||||
|
|
||||||
@profile
|
@profile
|
||||||
def make_frame(self, dataset_iterator: datasets.IterableDataset) -> Generator:
|
def make_frame(self, dataset_iterator: Backtrackable) -> Generator:
|
||||||
"""Makes a frame starting from a dataset iterator"""
|
"""Makes a frame starting from a dataset iterator"""
|
||||||
item = next(dataset_iterator)
|
item = next(dataset_iterator)
|
||||||
item = item_to_torch(item)
|
item = item_to_torch(item)
|
||||||
@@ -194,16 +231,20 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
|||||||
# Get episode index from the item
|
# Get episode index from the item
|
||||||
ep_idx = item["episode_index"]
|
ep_idx = item["episode_index"]
|
||||||
|
|
||||||
|
# Apply delta querying logic if necessary
|
||||||
|
if self.delta_indices is not None:
|
||||||
|
query_result, padding = self._get_delta_frames(dataset_iterator, item)
|
||||||
|
item = {**item, **query_result, **padding}
|
||||||
|
|
||||||
# Load video frames, when needed
|
# Load video frames, when needed
|
||||||
if len(self.meta.video_keys) > 0:
|
if len(self.meta.video_keys) > 0:
|
||||||
current_ts = item["timestamp"]
|
current_ts = item["timestamp"]
|
||||||
query_timestamps = self._get_query_timestamps(current_ts, None)
|
query_indices = self.delta_indices if self.delta_indices is not None else None
|
||||||
|
query_timestamps = self._get_query_timestamps(current_ts, query_indices)
|
||||||
video_frames = self._query_videos(query_timestamps, ep_idx)
|
video_frames = self._query_videos(query_timestamps, ep_idx)
|
||||||
item = {**video_frames, **item}
|
item = {**video_frames, **item}
|
||||||
|
|
||||||
# Add task as a string
|
item["task"] = self.meta.tasks.iloc[item["task_index"]].name
|
||||||
task_idx = item["task_index"]
|
|
||||||
item["task"] = self.meta.tasks.iloc[task_idx].name
|
|
||||||
|
|
||||||
yield item
|
yield item
|
||||||
|
|
||||||
@@ -215,8 +256,9 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
|||||||
query_timestamps = {}
|
query_timestamps = {}
|
||||||
for key in self.meta.video_keys:
|
for key in self.meta.video_keys:
|
||||||
if query_indices is not None and key in query_indices:
|
if query_indices is not None and key in query_indices:
|
||||||
timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
|
timestamps = current_ts + torch.tensor(query_indices[key]) / self.fps
|
||||||
query_timestamps[key] = torch.stack(timestamps).tolist()
|
# never query for negative timestamps!
|
||||||
|
query_timestamps[key] = list(filter(lambda x: x >= 0, timestamps.tolist()))
|
||||||
else:
|
else:
|
||||||
query_timestamps[key] = [current_ts]
|
query_timestamps[key] = [current_ts]
|
||||||
|
|
||||||
@@ -239,14 +281,145 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
|||||||
|
|
||||||
return item
|
return item
|
||||||
|
|
||||||
|
def _get_future_frame(self, dataset_iterator: Backtrackable, n_steps: int):
|
||||||
|
if dataset_iterator.can_go_ahead(n_steps):
|
||||||
|
return dataset_iterator.peek_ahead(n_steps)
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _get_previous_frame(self, dataset_iterator: Backtrackable, n_steps: int):
|
||||||
|
if dataset_iterator.can_go_back(n_steps):
|
||||||
|
return dataset_iterator.peek_back(n_steps)
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _get_delta_frames(self, dataset_iterator: Backtrackable, current_item: dict):
|
||||||
|
"""Get frames with delta offsets using the backtrackable iterator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
current_item (dict): Current item from the iterator.
|
||||||
|
ep_idx (int): Episode index.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (query_result, padding) - frames at delta offsets and padding info.
|
||||||
|
"""
|
||||||
|
current_episode_idx = current_item["episode_index"]
|
||||||
|
|
||||||
|
# Prepare results
|
||||||
|
query_result = {}
|
||||||
|
padding = {}
|
||||||
|
|
||||||
|
for key, delta_indices in self.delta_indices.items():
|
||||||
|
if key in self.meta.video_keys:
|
||||||
|
continue # visual frames are decoded separately
|
||||||
|
|
||||||
|
target_frames = []
|
||||||
|
is_pad = []
|
||||||
|
|
||||||
|
# NOTE(fracapuano): Optimize this. What's the point in checking all deltas after first error?
|
||||||
|
for delta in delta_indices:
|
||||||
|
if delta == 0:
|
||||||
|
# Current frame
|
||||||
|
target_frames.append(current_item[key])
|
||||||
|
is_pad.append(False)
|
||||||
|
|
||||||
|
elif delta < 0:
|
||||||
|
# Past frame. Use backtrackable iterator, looking back delta steps
|
||||||
|
try:
|
||||||
|
steps_back = abs(delta)
|
||||||
|
if dataset_iterator.can_peek_back(steps_back):
|
||||||
|
past_item = dataset_iterator.peek_back(steps_back)
|
||||||
|
past_item = item_to_torch(past_item)
|
||||||
|
|
||||||
|
# Check if it's from the same episode
|
||||||
|
if past_item["episode_index"] == current_episode_idx:
|
||||||
|
target_frames.append(past_item[key])
|
||||||
|
is_pad.append(False)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise LookBackError("Retrieved frame is from different episode!")
|
||||||
|
else:
|
||||||
|
raise LookBackError("Cannot go back further than the history buffer!")
|
||||||
|
|
||||||
|
except LookBackError:
|
||||||
|
target_frames.append(torch.zeros_like(current_item[key]))
|
||||||
|
is_pad.append(True)
|
||||||
|
|
||||||
|
elif delta > 0:
|
||||||
|
# Future frame - read ahead from the iterator
|
||||||
|
try:
|
||||||
|
if dataset_iterator.can_peek_ahead(delta):
|
||||||
|
future_item = dataset_iterator.peek_ahead(delta)
|
||||||
|
future_item = item_to_torch(future_item)
|
||||||
|
|
||||||
|
# Check if it's from the same episode
|
||||||
|
if future_item["episode_index"] == current_episode_idx:
|
||||||
|
target_frames.append(future_item[key])
|
||||||
|
is_pad.append(False)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise LookAheadError("Retrieved frame is from different episode!")
|
||||||
|
else:
|
||||||
|
raise LookAheadError("Cannot go ahead further than the lookahead buffer!")
|
||||||
|
|
||||||
|
except LookAheadError:
|
||||||
|
target_frames.append(torch.zeros_like(current_item[key]))
|
||||||
|
is_pad.append(True)
|
||||||
|
|
||||||
|
# Stack frames and add to results
|
||||||
|
if target_frames:
|
||||||
|
query_result[key] = torch.stack(target_frames)
|
||||||
|
padding[f"{key}.pad_masking"] = torch.BoolTensor(is_pad)
|
||||||
|
|
||||||
|
return query_result, padding
|
||||||
|
|
||||||
|
def _validate_delta_timestamp_keys(self, delta_timestamps: dict[list[float]]) -> None:
|
||||||
|
"""
|
||||||
|
Validate that all keys in delta_timestamps correspond to actual features in the dataset.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If any delta timestamp key doesn't correspond to a dataset feature.
|
||||||
|
"""
|
||||||
|
if delta_timestamps is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Get all available feature keys from the dataset metadata
|
||||||
|
available_features = set(self.meta.features.keys())
|
||||||
|
|
||||||
|
# Get all keys from delta_timestamps
|
||||||
|
delta_keys = set(delta_timestamps.keys())
|
||||||
|
|
||||||
|
# Find any keys that don't correspond to features
|
||||||
|
invalid_keys = delta_keys - available_features
|
||||||
|
|
||||||
|
if invalid_keys:
|
||||||
|
raise ValueError(
|
||||||
|
f"The following delta_timestamp keys do not correspond to dataset features: {invalid_keys}. "
|
||||||
|
f"Available features are: {sorted(available_features)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Example usage
|
# Example usage
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
repo_id = "lerobot/aloha_mobile_cabinet"
|
repo_id = "lerobot/aloha_mobile_cabinet"
|
||||||
dataset = StreamingLeRobotDataset(repo_id)
|
|
||||||
|
|
||||||
for i, frame in enumerate(dataset):
|
camera_key = "observation.images.cam_right_wrist"
|
||||||
|
fps = 50
|
||||||
|
|
||||||
|
delta_timestamps = {
|
||||||
|
# loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
|
||||||
|
camera_key: [-1, -0.5, -0.20, 0],
|
||||||
|
# loads 6 state vectors: 1.5 seconds before, 1 second before, ... 200 ms, 100 ms, and current frame
|
||||||
|
"observation.state": [-1.5, -1, -0.5, -0.20, -0.10, 0],
|
||||||
|
# loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
|
||||||
|
"action": [t / fps for t in range(64)],
|
||||||
|
}
|
||||||
|
|
||||||
|
dataset = StreamingLeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
|
||||||
|
|
||||||
|
for i, frame in tqdm(enumerate(dataset)):
|
||||||
print(frame)
|
print(frame)
|
||||||
|
if i > 1000:  # stop after roughly the first 1000 frames
|
||||||
if i > 10: # only stream first 10 frames
|
|
||||||
break
|
break
|
||||||
|
|||||||
Reference in New Issue
Block a user