Add Streaming Dataset (#1613)
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
This commit is contained in:
committed by
GitHub
parent
f55c6e89f0
commit
33cad37054
116
examples/5_train_with_streaming.py
Normal file
116
examples/5_train_with_streaming.py
Normal file
@@ -0,0 +1,116 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""This script demonstrates how to train a Diffusion Policy on the PushT environment,
|
||||
using a dataset processed in streaming mode.
|
||||
|
||||
Once you have trained a model with this script, you can try to evaluate it on
|
||||
examples/2_evaluate_pretrained_policy.py
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType
|
||||
from lerobot.constants import ACTION
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
from lerobot.datasets.utils import dataset_to_policy_features
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
|
||||
|
||||
def main():
|
||||
# Create a directory to store the training checkpoint.
|
||||
output_directory = Path("outputs/train/example_streaming_dataset")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Selects the "best" device available
|
||||
device = (
|
||||
torch.device("cuda")
|
||||
if torch.cuda.is_available()
|
||||
else torch.device("mps")
|
||||
if torch.backends.mps.is_available()
|
||||
else torch.device("cpu")
|
||||
)
|
||||
print(f"Using device: {device}")
|
||||
|
||||
training_steps = 10
|
||||
log_freq = 1
|
||||
|
||||
dataset_id = (
|
||||
"aractingi/droid_1.0.1" # 26M frames! Would require 4TB of disk space if installed locally (:
|
||||
)
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
features = dataset_to_policy_features(dataset_metadata.features)
|
||||
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
input_features = {key: ft for key, ft in features.items() if key not in output_features}
|
||||
|
||||
# We can now instantiate our policy with this config and the dataset stats.
|
||||
cfg = ACTConfig(input_features=input_features, output_features=output_features)
|
||||
policy = ACTPolicy(cfg, dataset_stats=dataset_metadata.stats)
|
||||
policy.train()
|
||||
policy.to(device)
|
||||
|
||||
# Delta timestamps are used to (1) augment frames used during training and (2) supervise the policy.
|
||||
# Here, we use delta-timestamps to only provide ground truth actions for supervision
|
||||
delta_timestamps = {
|
||||
ACTION: [t / dataset_metadata.fps for t in range(cfg.n_action_steps)],
|
||||
}
|
||||
|
||||
# Instantiating the training dataset in streaming mode allows to not consume up memory as the data is fetched
|
||||
# iteratively rather than being load into memory all at once. Retrieved frames are shuffled across epochs
|
||||
dataset = StreamingLeRobotDataset(dataset_id, delta_timestamps=delta_timestamps, tolerance_s=1e-3)
|
||||
|
||||
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=4,
|
||||
batch_size=16,
|
||||
pin_memory=device.type != "cpu",
|
||||
drop_last=True,
|
||||
prefetch_factor=2, # loads batches with multiprocessing while policy trains
|
||||
)
|
||||
|
||||
# Run training loop.
|
||||
step = 0
|
||||
done = False
|
||||
while not done:
|
||||
for batch in dataloader:
|
||||
batch = {
|
||||
k: (v.type(torch.float32) if isinstance(v, torch.Tensor) and v.dtype != torch.bool else v)
|
||||
for k, v in batch.items()
|
||||
}
|
||||
batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
|
||||
|
||||
# batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
|
||||
loss, _ = policy.forward(batch)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if step % log_freq == 0:
|
||||
print(f"step: {step} loss: {loss.item():.3f}")
|
||||
step += 1
|
||||
if step >= training_steps:
|
||||
done = True
|
||||
break
|
||||
|
||||
# Save a policy checkpoint.
|
||||
policy.save_pretrained(output_directory)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -37,6 +37,7 @@ class DatasetConfig:
|
||||
revision: str | None = None
|
||||
use_imagenet_stats: bool = True
|
||||
video_backend: str = field(default_factory=get_safe_default_codec)
|
||||
streaming: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -52,3 +52,8 @@ HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expandu
|
||||
# calibration dir
|
||||
default_calibration_path = HF_LEROBOT_HOME / "calibration"
|
||||
HF_LEROBOT_CALIBRATION = Path(os.getenv("HF_LEROBOT_CALIBRATION", default_calibration_path)).expanduser()
|
||||
|
||||
|
||||
# streaming datasets
|
||||
LOOKBACK_BACKTRACKTABLE = 100
|
||||
LOOKAHEAD_BACKTRACKTABLE = 100
|
||||
|
||||
@@ -25,6 +25,7 @@ from lerobot.datasets.lerobot_dataset import (
|
||||
LeRobotDatasetMetadata,
|
||||
MultiLeRobotDataset,
|
||||
)
|
||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
from lerobot.datasets.transforms import ImageTransforms
|
||||
|
||||
IMAGENET_STATS = {
|
||||
@@ -87,15 +88,26 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
|
||||
)
|
||||
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
episodes=cfg.dataset.episodes,
|
||||
delta_timestamps=delta_timestamps,
|
||||
image_transforms=image_transforms,
|
||||
revision=cfg.dataset.revision,
|
||||
video_backend=cfg.dataset.video_backend,
|
||||
)
|
||||
if not cfg.dataset.streaming:
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
episodes=cfg.dataset.episodes,
|
||||
delta_timestamps=delta_timestamps,
|
||||
image_transforms=image_transforms,
|
||||
revision=cfg.dataset.revision,
|
||||
video_backend=cfg.dataset.video_backend,
|
||||
)
|
||||
else:
|
||||
dataset = StreamingLeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
episodes=cfg.dataset.episodes,
|
||||
delta_timestamps=delta_timestamps,
|
||||
image_transforms=image_transforms,
|
||||
revision=cfg.dataset.revision,
|
||||
max_num_shards=cfg.num_workers,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
|
||||
dataset = MultiLeRobotDataset(
|
||||
|
||||
@@ -129,6 +129,10 @@ class LeRobotDatasetMetadata:
|
||||
ignore_patterns=ignore_patterns,
|
||||
)
|
||||
|
||||
@property
|
||||
def url_root(self) -> str:
|
||||
return f"hf://datasets/{self.repo_id}"
|
||||
|
||||
@property
|
||||
def _version(self) -> packaging.version.Version:
|
||||
"""Codebase version used to create this dataset."""
|
||||
|
||||
535
src/lerobot/datasets/streaming_dataset.py
Normal file
535
src/lerobot/datasets/streaming_dataset.py
Normal file
@@ -0,0 +1,535 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from collections.abc import Callable, Generator, Iterator
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
|
||||
from lerobot.constants import HF_LEROBOT_HOME, LOOKAHEAD_BACKTRACKTABLE, LOOKBACK_BACKTRACKTABLE
|
||||
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import (
|
||||
Backtrackable,
|
||||
LookAheadError,
|
||||
LookBackError,
|
||||
check_version_compatibility,
|
||||
find_float_index,
|
||||
get_delta_indices,
|
||||
is_float_in_list,
|
||||
item_to_torch,
|
||||
safe_shard,
|
||||
)
|
||||
from lerobot.datasets.video_utils import (
|
||||
VideoDecoderCache,
|
||||
decode_video_frames_torchcodec,
|
||||
)
|
||||
|
||||
|
||||
class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
"""LeRobotDataset with streaming capabilities.
|
||||
|
||||
This class extends LeRobotDataset to add streaming functionality, allowing data to be streamed
|
||||
rather than loaded entirely into memory. This is especially useful for large datasets that may
|
||||
not fit in memory or when you want to quickly explore a dataset without downloading it completely.
|
||||
|
||||
The key innovation is using a Backtrackable iterator that maintains a bounded buffer of recent
|
||||
items, allowing us to access previous frames for delta timestamps without loading the entire
|
||||
dataset into memory.
|
||||
|
||||
Example:
|
||||
Basic usage:
|
||||
```python
|
||||
from lerobot.common.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
|
||||
# Create a streaming dataset with delta timestamps
|
||||
delta_timestamps = {
|
||||
"observation.image": [-1.0, -0.5, 0.0], # 1 sec ago, 0.5 sec ago, current
|
||||
"action": [0.0, 0.1, 0.2], # current, 0.1 sec future, 0.2 sec future
|
||||
}
|
||||
|
||||
dataset = StreamingLeRobotDataset(
|
||||
repo_id="your-dataset-repo-id",
|
||||
delta_timestamps=delta_timestamps,
|
||||
streaming=True,
|
||||
buffer_size=1000,
|
||||
)
|
||||
|
||||
# Iterate over the dataset
|
||||
for i, item in enumerate(dataset):
|
||||
print(f"Sample {i}: Episode {item['episode_index']} Frame {item['frame_index']}")
|
||||
# item will contain stacked frames according to delta_timestamps
|
||||
if i >= 10:
|
||||
break
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
repo_id: str,
|
||||
root: str | Path | None = None,
|
||||
episodes: list[int] | None = None,
|
||||
image_transforms: Callable | None = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
tolerance_s: float = 1e-4,
|
||||
revision: str | None = None,
|
||||
force_cache_sync: bool = False,
|
||||
streaming: bool = True,
|
||||
buffer_size: int = 1000,
|
||||
max_num_shards: int = 16,
|
||||
seed: int = 42,
|
||||
rng: np.random.Generator | None = None,
|
||||
shuffle: bool = True,
|
||||
):
|
||||
"""Initialize a StreamingLeRobotDataset.
|
||||
|
||||
Args:
|
||||
repo_id (str): This is the repo id that will be used to fetch the dataset.
|
||||
root (Path | None, optional): Local directory to use for downloading/writing files.
|
||||
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
|
||||
their episode_index in this list.
|
||||
image_transforms (Callable | None, optional): Transform to apply to image data.
|
||||
tolerance_s (float, optional): Tolerance in seconds for timestamp matching.
|
||||
revision (str, optional): Git revision id (branch name, tag, or commit hash).
|
||||
force_cache_sync (bool, optional): Flag to sync and refresh local files first.
|
||||
streaming (bool, optional): Whether to stream the dataset or load it all. Defaults to True.
|
||||
buffer_size (int, optional): Buffer size for shuffling when streaming. Defaults to 1000.
|
||||
max_num_shards (int, optional): Number of shards to re-shard the input dataset into. Defaults to 16.
|
||||
seed (int, optional): Reproducibility random seed.
|
||||
rng (np.random.Generator | None, optional): Random number generator.
|
||||
shuffle (bool, optional): Whether to shuffle the dataset across exhaustions. Defaults to True.
|
||||
"""
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
|
||||
self.streaming_from_local = root is not None
|
||||
|
||||
self.image_transforms = image_transforms
|
||||
self.episodes = episodes
|
||||
self.tolerance_s = tolerance_s
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self.seed = seed
|
||||
self.rng = rng if rng is not None else np.random.default_rng(seed)
|
||||
self.shuffle = shuffle
|
||||
|
||||
self.streaming = streaming
|
||||
self.buffer_size = buffer_size
|
||||
|
||||
# We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown)
|
||||
self.video_decoder_cache = None
|
||||
|
||||
self.root.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
# Load metadata
|
||||
self.meta = LeRobotDatasetMetadata(
|
||||
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
|
||||
)
|
||||
# Check version
|
||||
check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
|
||||
|
||||
self.delta_timestamps = None
|
||||
self.delta_indices = None
|
||||
|
||||
if delta_timestamps is not None:
|
||||
self._validate_delta_timestamp_keys(delta_timestamps) # raises ValueError if invalid
|
||||
self.delta_timestamps = delta_timestamps
|
||||
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
|
||||
|
||||
self.hf_dataset: datasets.IterableDataset = load_dataset(
|
||||
self.repo_id if not self.streaming_from_local else str(self.root),
|
||||
split="train",
|
||||
streaming=self.streaming,
|
||||
data_files="data/*/*.parquet",
|
||||
revision=self.revision,
|
||||
)
|
||||
|
||||
self.num_shards = min(self.hf_dataset.num_shards, max_num_shards)
|
||||
|
||||
@property
|
||||
def num_frames(self):
|
||||
return self.meta.total_frames
|
||||
|
||||
@property
|
||||
def num_episodes(self):
|
||||
return self.meta.total_episodes
|
||||
|
||||
@property
|
||||
def fps(self):
|
||||
return self.meta.fps
|
||||
|
||||
@staticmethod
|
||||
def _iter_random_indices(
|
||||
rng: np.random.Generator, buffer_size: int, random_batch_size=100
|
||||
) -> Iterator[int]:
|
||||
while True:
|
||||
yield from (int(i) for i in rng.integers(0, buffer_size, size=random_batch_size))
|
||||
|
||||
@staticmethod
|
||||
def _infinite_generator_over_elements(rng: np.random.Generator, elements: list[int]) -> Iterator[int]:
|
||||
while True:
|
||||
yield rng.choice(elements)
|
||||
|
||||
# TODO(fracapuano): Implement multi-threaded prefetching to accelerate data loading.
|
||||
# The current sequential iteration is a bottleneck. A producer-consumer pattern
|
||||
# could be used with a ThreadPoolExecutor to run `make_frame` (especially video decoding)
|
||||
# in parallel, feeding a queue from which this iterator will yield processed items.
|
||||
def __iter__(self) -> Iterator[dict[str, torch.Tensor]]:
|
||||
if self.video_decoder_cache is None:
|
||||
self.video_decoder_cache = VideoDecoderCache()
|
||||
|
||||
# keep the same seed across exhaustions if shuffle is False, otherwise shuffle data across exhaustions
|
||||
rng = np.random.default_rng(self.seed) if not self.shuffle else self.rng
|
||||
|
||||
buffer_indices_generator = self._iter_random_indices(rng, self.buffer_size)
|
||||
|
||||
idx_to_backtrack_dataset = {
|
||||
idx: self._make_backtrackable_dataset(safe_shard(self.hf_dataset, idx, self.num_shards))
|
||||
for idx in range(self.num_shards)
|
||||
}
|
||||
|
||||
# This buffer is populated while iterating on the dataset's shards
|
||||
# the logic is to add 2 levels of randomness:
|
||||
# (1) sample one shard at random from the ones available, and
|
||||
# (2) sample one frame from the shard sampled at (1)
|
||||
frames_buffer = []
|
||||
while available_shards := list(idx_to_backtrack_dataset.keys()):
|
||||
shard_key = next(self._infinite_generator_over_elements(rng, available_shards))
|
||||
backtrack_dataset = idx_to_backtrack_dataset[shard_key] # selects which shard to iterate on
|
||||
|
||||
try:
|
||||
for frame in self.make_frame(backtrack_dataset):
|
||||
if len(frames_buffer) == self.buffer_size:
|
||||
i = next(buffer_indices_generator) # samples a element from the buffer
|
||||
yield frames_buffer[i]
|
||||
frames_buffer[i] = frame
|
||||
else:
|
||||
frames_buffer.append(frame)
|
||||
break # random shard sampled, switch shard
|
||||
except (
|
||||
RuntimeError,
|
||||
StopIteration,
|
||||
): # NOTE: StopIteration inside a generator throws a RuntimeError since python 3.7
|
||||
del idx_to_backtrack_dataset[shard_key] # Remove exhausted shard, onto another shard
|
||||
|
||||
# Once shards are all exhausted, shuffle the buffer and yield the remaining frames
|
||||
rng.shuffle(frames_buffer)
|
||||
yield from frames_buffer
|
||||
|
||||
def _get_window_steps(
|
||||
self, delta_timestamps: dict[str, list[float]] | None = None, dynamic_bounds: bool = False
|
||||
) -> tuple[int, int]:
|
||||
if delta_timestamps is None:
|
||||
return 1, 1
|
||||
|
||||
if not dynamic_bounds:
|
||||
# Fix the windows
|
||||
lookback = LOOKBACK_BACKTRACKTABLE
|
||||
lookahead = LOOKAHEAD_BACKTRACKTABLE
|
||||
else:
|
||||
# Dynamically adjust the windows based on the given delta_timesteps
|
||||
all_timestamps = sum(delta_timestamps.values(), [])
|
||||
lookback = min(all_timestamps) * self.fps
|
||||
lookahead = max(all_timestamps) * self.fps
|
||||
|
||||
# When lookback is >=0 it means no negative timesteps have been provided
|
||||
lookback = 0 if lookback >= 0 else (lookback * -1)
|
||||
|
||||
return lookback, lookahead
|
||||
|
||||
def _make_backtrackable_dataset(self, dataset: datasets.IterableDataset) -> Backtrackable:
|
||||
lookback, lookahead = self._get_window_steps(self.delta_timestamps)
|
||||
return Backtrackable(dataset, history=lookback, lookahead=lookahead)
|
||||
|
||||
def _make_timestamps_from_indices(
|
||||
self, start_ts: float, indices: dict[str, list[int]] | None = None
|
||||
) -> dict[str, list[float]]:
|
||||
if indices is not None:
|
||||
return {
|
||||
key: (
|
||||
start_ts + torch.tensor(indices[key]) / self.fps
|
||||
).tolist() # NOTE: why not delta_timestamps directly?
|
||||
for key in self.delta_timestamps
|
||||
}
|
||||
else:
|
||||
return dict.fromkeys(self.meta.video_keys, [start_ts])
|
||||
|
||||
def _make_padding_camera_frame(self, camera_key: str):
|
||||
"""Variable-shape padding frame for given camera keys, given in (H, W, C)"""
|
||||
return torch.zeros(self.meta.info["features"][camera_key]["shape"]).permute(-1, 0, 1)
|
||||
|
||||
def _get_video_frame_padding_mask(
|
||||
self,
|
||||
video_frames: dict[str, torch.Tensor],
|
||||
query_timestamps: dict[str, list[float]],
|
||||
original_timestamps: dict[str, list[float]],
|
||||
) -> dict[str, torch.BoolTensor]:
|
||||
padding_mask = {}
|
||||
|
||||
for video_key, timestamps in original_timestamps.items():
|
||||
if video_key not in video_frames:
|
||||
continue # only padding on video keys that are available
|
||||
frames = []
|
||||
mask = []
|
||||
padding_frame = self._make_padding_camera_frame(video_key)
|
||||
for ts in timestamps:
|
||||
if is_float_in_list(ts, query_timestamps[video_key]):
|
||||
idx = find_float_index(ts, query_timestamps[video_key])
|
||||
frames.append(video_frames[video_key][idx, :])
|
||||
mask.append(False)
|
||||
else:
|
||||
frames.append(padding_frame)
|
||||
mask.append(True)
|
||||
|
||||
padding_mask[f"{video_key}_is_pad"] = torch.BoolTensor(mask)
|
||||
|
||||
return padding_mask
|
||||
|
||||
def make_frame(
|
||||
self, dataset_iterator: Backtrackable, previous_dataset_iterator: Backtrackable | None = None
|
||||
) -> Generator:
|
||||
"""Makes a frame starting from a dataset iterator"""
|
||||
item = next(dataset_iterator)
|
||||
item = item_to_torch(item)
|
||||
|
||||
updates = [] # list of "updates" to apply to the item retrieved from hf_dataset (w/o camera features)
|
||||
|
||||
# Get episode index from the item
|
||||
ep_idx = item["episode_index"]
|
||||
|
||||
# "timestamp" restarts from 0 for each episode, whereas we need a global timestep within the single .mp4 file (given by index/fps)
|
||||
current_ts = item["index"] / self.fps
|
||||
|
||||
episode_boundaries_ts = {
|
||||
key: (
|
||||
self.meta.episodes[ep_idx][f"videos/{key}/from_timestamp"],
|
||||
self.meta.episodes[ep_idx][f"videos/{key}/to_timestamp"],
|
||||
)
|
||||
for key in self.meta.video_keys
|
||||
}
|
||||
|
||||
# Apply delta querying logic if necessary
|
||||
if self.delta_indices is not None:
|
||||
query_result, padding = self._get_delta_frames(dataset_iterator, item)
|
||||
updates.append(query_result)
|
||||
updates.append(padding)
|
||||
|
||||
# Load video frames, when needed
|
||||
if len(self.meta.video_keys) > 0:
|
||||
original_timestamps = self._make_timestamps_from_indices(current_ts, self.delta_indices)
|
||||
|
||||
# Some timestamps might not result available considering the episode's boundaries
|
||||
query_timestamps = self._get_query_timestamps(
|
||||
current_ts, self.delta_indices, episode_boundaries_ts
|
||||
)
|
||||
video_frames = self._query_videos(query_timestamps, ep_idx)
|
||||
|
||||
if self.image_transforms is not None:
|
||||
image_keys = self.meta.camera_keys
|
||||
for cam in image_keys:
|
||||
video_frames[cam] = self.image_transforms(video_frames[cam])
|
||||
|
||||
updates.append(video_frames)
|
||||
|
||||
if self.delta_indices is not None:
|
||||
# We always return the same number of frames. Unavailable frames are padded.
|
||||
padding_mask = self._get_video_frame_padding_mask(
|
||||
video_frames, query_timestamps, original_timestamps
|
||||
)
|
||||
updates.append(padding_mask)
|
||||
|
||||
result = item.copy()
|
||||
for update in updates:
|
||||
result.update(update)
|
||||
|
||||
result["task"] = self.meta.tasks.iloc[item["task_index"]].name
|
||||
|
||||
yield result
|
||||
|
||||
def _get_query_timestamps(
|
||||
self,
|
||||
current_ts: float,
|
||||
query_indices: dict[str, list[int]] | None = None,
|
||||
episode_boundaries_ts: dict[str, tuple[float, float]] | None = None,
|
||||
) -> dict[str, list[float]]:
|
||||
query_timestamps = {}
|
||||
keys_to_timestamps = self._make_timestamps_from_indices(current_ts, query_indices)
|
||||
for key in self.meta.video_keys:
|
||||
if query_indices is not None and key in query_indices:
|
||||
timestamps = keys_to_timestamps[key]
|
||||
# Clamp out timesteps outside of episode boundaries
|
||||
query_timestamps[key] = torch.clamp(
|
||||
torch.tensor(timestamps), *episode_boundaries_ts[key]
|
||||
).tolist()
|
||||
|
||||
else:
|
||||
query_timestamps[key] = [current_ts]
|
||||
|
||||
return query_timestamps
|
||||
|
||||
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict:
|
||||
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
|
||||
in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a
|
||||
Segmentation Fault. This probably happens because a memory reference to the video loader is created in
|
||||
the main process and a subprocess fails to access it.
|
||||
"""
|
||||
|
||||
item = {}
|
||||
for video_key, query_ts in query_timestamps.items():
|
||||
root = self.meta.url_root if self.streaming and not self.streaming_from_local else self.root
|
||||
video_path = f"{root}/{self.meta.get_video_file_path(ep_idx, video_key)}"
|
||||
frames = decode_video_frames_torchcodec(
|
||||
video_path, query_ts, self.tolerance_s, decoder_cache=self.video_decoder_cache
|
||||
)
|
||||
|
||||
item[video_key] = frames.squeeze(0) if len(query_ts) == 1 else frames
|
||||
|
||||
return item
|
||||
|
||||
def _get_delta_frames(self, dataset_iterator: Backtrackable, current_item: dict):
|
||||
# TODO(fracapuano): Modularize this function, refactor the code
|
||||
"""Get frames with delta offsets using the backtrackable iterator.
|
||||
|
||||
Args:
|
||||
current_item (dict): Current item from the iterator.
|
||||
ep_idx (int): Episode index.
|
||||
|
||||
Returns:
|
||||
tuple: (query_result, padding) - frames at delta offsets and padding info.
|
||||
"""
|
||||
current_episode_idx = current_item["episode_index"]
|
||||
|
||||
# Prepare results
|
||||
query_result = {}
|
||||
padding = {}
|
||||
|
||||
for key, delta_indices in self.delta_indices.items():
|
||||
if key in self.meta.video_keys:
|
||||
continue # visual frames are decoded separately
|
||||
|
||||
target_frames = []
|
||||
is_pad = []
|
||||
|
||||
# Create a results dictionary to store frames in processing order, then reconstruct original order for stacking
|
||||
delta_results = {}
|
||||
|
||||
# Separate and sort deltas by difficulty (easier operations first)
|
||||
negative_deltas = sorted([d for d in delta_indices if d < 0], reverse=True) # [-1, -2, -3, ...]
|
||||
positive_deltas = sorted([d for d in delta_indices if d > 0]) # [1, 2, 3, ...]
|
||||
zero_deltas = [d for d in delta_indices if d == 0]
|
||||
|
||||
# Process zero deltas (current frame)
|
||||
for delta in zero_deltas:
|
||||
delta_results[delta] = (
|
||||
current_item[key],
|
||||
False,
|
||||
)
|
||||
|
||||
# Process negative deltas in order of increasing difficulty
|
||||
lookback_failed = False
|
||||
|
||||
last_successful_frame = current_item[key]
|
||||
|
||||
for delta in negative_deltas:
|
||||
if lookback_failed:
|
||||
delta_results[delta] = (last_successful_frame, True)
|
||||
continue
|
||||
|
||||
try:
|
||||
steps_back = abs(delta)
|
||||
if dataset_iterator.can_peek_back(steps_back):
|
||||
past_item = dataset_iterator.peek_back(steps_back)
|
||||
past_item = item_to_torch(past_item)
|
||||
|
||||
if past_item["episode_index"] == current_episode_idx:
|
||||
delta_results[delta] = (past_item[key], False)
|
||||
last_successful_frame = past_item[key]
|
||||
|
||||
else:
|
||||
raise LookBackError("Retrieved frame is from different episode!")
|
||||
else:
|
||||
raise LookBackError("Cannot go back further than the history buffer!")
|
||||
|
||||
except LookBackError:
|
||||
delta_results[delta] = (last_successful_frame, True)
|
||||
lookback_failed = True # All subsequent negative deltas will also fail
|
||||
|
||||
# Process positive deltas in order of increasing difficulty
|
||||
lookahead_failed = False
|
||||
last_successful_frame = current_item[key]
|
||||
|
||||
for delta in positive_deltas:
|
||||
if lookahead_failed:
|
||||
delta_results[delta] = (last_successful_frame, True)
|
||||
continue
|
||||
|
||||
try:
|
||||
if dataset_iterator.can_peek_ahead(delta):
|
||||
future_item = dataset_iterator.peek_ahead(delta)
|
||||
future_item = item_to_torch(future_item)
|
||||
|
||||
if future_item["episode_index"] == current_episode_idx:
|
||||
delta_results[delta] = (future_item[key], False)
|
||||
last_successful_frame = future_item[key]
|
||||
|
||||
else:
|
||||
raise LookAheadError("Retrieved frame is from different episode!")
|
||||
else:
|
||||
raise LookAheadError("Cannot go ahead further than the lookahead buffer!")
|
||||
|
||||
except LookAheadError:
|
||||
delta_results[delta] = (last_successful_frame, True)
|
||||
lookahead_failed = True # All subsequent positive deltas will also fail
|
||||
|
||||
# Reconstruct original order for stacking
|
||||
for delta in delta_indices:
|
||||
frame, is_padded = delta_results[delta]
|
||||
|
||||
# add batch dimension for stacking
|
||||
target_frames.append(frame) # frame.unsqueeze(0))
|
||||
is_pad.append(is_padded)
|
||||
|
||||
# Stack frames and add to results
|
||||
if target_frames:
|
||||
query_result[key] = torch.stack(target_frames)
|
||||
padding[f"{key}_is_pad"] = torch.BoolTensor(is_pad)
|
||||
|
||||
return query_result, padding
|
||||
|
||||
def _validate_delta_timestamp_keys(self, delta_timestamps: dict[list[float]]) -> None:
|
||||
"""
|
||||
Validate that all keys in delta_timestamps correspond to actual features in the dataset.
|
||||
|
||||
Raises:
|
||||
ValueError: If any delta timestamp key doesn't correspond to a dataset feature.
|
||||
"""
|
||||
if delta_timestamps is None:
|
||||
return
|
||||
|
||||
# Get all available feature keys from the dataset metadata
|
||||
available_features = set(self.meta.features.keys())
|
||||
|
||||
# Get all keys from delta_timestamps
|
||||
delta_keys = set(delta_timestamps.keys())
|
||||
|
||||
# Find any keys that don't correspond to features
|
||||
invalid_keys = delta_keys - available_features
|
||||
|
||||
if invalid_keys:
|
||||
raise ValueError(
|
||||
f"The following delta_timestamp keys do not correspond to dataset features: {invalid_keys}. "
|
||||
f"Available features are: {sorted(available_features)}"
|
||||
)
|
||||
@@ -17,10 +17,11 @@ import contextlib
|
||||
import importlib.resources
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Iterator
|
||||
from collections import deque
|
||||
from collections.abc import Iterable, Iterator
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from typing import Any
|
||||
from typing import Any, Deque, Generic, TypeVar
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
@@ -86,6 +87,8 @@ DEFAULT_FEATURES = {
|
||||
"task_index": {"dtype": "int64", "shape": (1,), "names": None},
|
||||
}
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def get_parquet_file_size_in_mb(parquet_path: str | Path) -> float:
|
||||
metadata = pq.read_metadata(parquet_path)
|
||||
@@ -776,3 +779,230 @@ def to_parquet_with_hf_images(df: pandas.DataFrame, path: Path) -> None:
|
||||
"""
|
||||
# TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
|
||||
datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
|
||||
|
||||
|
||||
def item_to_torch(item: dict) -> dict:
|
||||
"""Convert all items in a dictionary to PyTorch tensors where appropriate.
|
||||
|
||||
This function is used to convert an item from a streaming dataset to PyTorch tensors.
|
||||
|
||||
Args:
|
||||
item (dict): Dictionary of items from a dataset.
|
||||
|
||||
Returns:
|
||||
dict: Dictionary with all tensor-like items converted to torch.Tensor.
|
||||
"""
|
||||
for key, val in item.items():
|
||||
if isinstance(val, (np.ndarray, list)) and key not in ["task"]:
|
||||
# Convert numpy arrays and lists to torch tensors
|
||||
item[key] = torch.tensor(val)
|
||||
return item
|
||||
|
||||
|
||||
def is_float_in_list(target, float_list, threshold=1e-6):
|
||||
return any(abs(target - x) <= threshold for x in float_list)
|
||||
|
||||
|
||||
def find_float_index(target, float_list, threshold=1e-6):
|
||||
for i, x in enumerate(float_list):
|
||||
if abs(target - x) <= threshold:
|
||||
return i
|
||||
return -1
|
||||
|
||||
|
||||
class LookBackError(Exception):
|
||||
"""
|
||||
Exception raised when trying to look back in the history of a Backtrackable object.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class LookAheadError(Exception):
|
||||
"""
|
||||
Exception raised when trying to look ahead in the future of a Backtrackable object.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class Backtrackable(Generic[T]):
|
||||
"""
|
||||
Wrap any iterator/iterable so you can step back up to `history` items
|
||||
and look ahead up to `lookahead` items.
|
||||
|
||||
This is useful for streaming datasets where you need to access previous and future items
|
||||
but can't load the entire dataset into memory.
|
||||
|
||||
Example:
|
||||
-------
|
||||
```python
|
||||
ds = load_dataset("c4", "en", streaming=True, split="train")
|
||||
rev = Backtrackable(ds, history=3, lookahead=2)
|
||||
|
||||
x0 = next(rev) # forward
|
||||
x1 = next(rev)
|
||||
x2 = next(rev)
|
||||
|
||||
# Look ahead
|
||||
x3_peek = rev.peek_ahead(1) # next item without moving cursor
|
||||
x4_peek = rev.peek_ahead(2) # two items ahead
|
||||
|
||||
# Look back
|
||||
x1_again = rev.peek_back(1) # previous item without moving cursor
|
||||
x0_again = rev.peek_back(2) # two items back
|
||||
|
||||
# Move backward
|
||||
x1_back = rev.prev() # back one step
|
||||
next(rev) # returns x2, continues forward from where we were
|
||||
```
|
||||
"""
|
||||
|
||||
__slots__ = ("_source", "_back_buf", "_ahead_buf", "_cursor", "_history", "_lookahead")
|
||||
|
||||
def __init__(self, iterable: Iterable[T], *, history: int = 1, lookahead: int = 0):
|
||||
if history < 1:
|
||||
raise ValueError("history must be >= 1")
|
||||
if lookahead <= 0:
|
||||
raise ValueError("lookahead must be > 0")
|
||||
|
||||
self._source: Iterator[T] = iter(iterable)
|
||||
self._back_buf: Deque[T] = deque(maxlen=history)
|
||||
self._ahead_buf: Deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque()
|
||||
self._cursor: int = 0
|
||||
self._history = history
|
||||
self._lookahead = lookahead
|
||||
|
||||
def __iter__(self) -> "Backtrackable[T]":
|
||||
return self
|
||||
|
||||
def __next__(self) -> T:
|
||||
# If we've stepped back, consume from back buffer first
|
||||
if self._cursor < 0: # -1 means "last item", etc.
|
||||
self._cursor += 1
|
||||
return self._back_buf[self._cursor]
|
||||
|
||||
# If we have items in the ahead buffer, use them first
|
||||
item = self._ahead_buf.popleft() if self._ahead_buf else next(self._source)
|
||||
|
||||
# Add current item to back buffer and reset cursor
|
||||
self._back_buf.append(item)
|
||||
self._cursor = 0
|
||||
return item
|
||||
|
||||
def prev(self) -> T:
|
||||
"""
|
||||
Step one item back in history and return it.
|
||||
Raises IndexError if already at the oldest buffered item.
|
||||
"""
|
||||
if len(self._back_buf) + self._cursor <= 1:
|
||||
raise LookBackError("At start of history")
|
||||
|
||||
self._cursor -= 1
|
||||
return self._back_buf[self._cursor]
|
||||
|
||||
def peek_back(self, n: int = 1) -> T:
|
||||
"""
|
||||
Look `n` items back (n=1 == previous item) without moving the cursor.
|
||||
"""
|
||||
if n < 0 or n + 1 > len(self._back_buf) + self._cursor:
|
||||
raise LookBackError("peek_back distance out of range")
|
||||
|
||||
return self._back_buf[self._cursor - (n + 1)]
|
||||
|
||||
def peek_ahead(self, n: int = 1) -> T:
|
||||
"""
|
||||
Look `n` items ahead (n=1 == next item) without moving the cursor.
|
||||
Fills the ahead buffer if necessary.
|
||||
"""
|
||||
if n < 1:
|
||||
raise LookAheadError("peek_ahead distance must be 1 or more")
|
||||
elif n > self._lookahead:
|
||||
raise LookAheadError("peek_ahead distance exceeds lookahead limit")
|
||||
|
||||
# Fill ahead buffer if we don't have enough items
|
||||
while len(self._ahead_buf) < n:
|
||||
try:
|
||||
item = next(self._source)
|
||||
self._ahead_buf.append(item)
|
||||
|
||||
except StopIteration as err:
|
||||
raise LookAheadError("peek_ahead: not enough items in source") from err
|
||||
|
||||
return self._ahead_buf[n - 1]
|
||||
|
||||
def history(self) -> list[T]:
|
||||
"""
|
||||
Return a copy of the buffered history (most recent last).
|
||||
The list length ≤ `history` argument passed at construction.
|
||||
"""
|
||||
if self._cursor == 0:
|
||||
return list(self._back_buf)
|
||||
|
||||
# When cursor<0, slice so the order remains chronological
|
||||
return list(self._back_buf)[: self._cursor or None]
|
||||
|
||||
def lookahead_buffer(self) -> list[T]:
|
||||
"""
|
||||
Return a copy of the current lookahead buffer.
|
||||
"""
|
||||
return list(self._ahead_buf)
|
||||
|
||||
def can_peek_back(self, steps: int = 1) -> bool:
|
||||
"""
|
||||
Check if we can go back `steps` items without raising an IndexError.
|
||||
"""
|
||||
return steps <= len(self._back_buf) + self._cursor
|
||||
|
||||
def can_peek_ahead(self, steps: int = 1) -> bool:
|
||||
"""
|
||||
Check if we can peek ahead `steps` items.
|
||||
This may involve trying to fill the ahead buffer.
|
||||
"""
|
||||
if self._lookahead > 0 and steps > self._lookahead:
|
||||
return False
|
||||
|
||||
# Try to fill ahead buffer to check if we can peek that far
|
||||
try:
|
||||
while len(self._ahead_buf) < steps:
|
||||
if self._lookahead > 0 and len(self._ahead_buf) >= self._lookahead:
|
||||
return False
|
||||
item = next(self._source)
|
||||
self._ahead_buf.append(item)
|
||||
return True
|
||||
except StopIteration:
|
||||
return False
|
||||
|
||||
def reset_cursor(self) -> None:
|
||||
"""
|
||||
Reset cursor to the most recent position (equivalent to calling next()
|
||||
until you're back to the latest item).
|
||||
"""
|
||||
self._cursor = 0
|
||||
|
||||
def clear_ahead_buffer(self) -> None:
|
||||
"""
|
||||
Clear the ahead buffer, discarding any pre-fetched items.
|
||||
"""
|
||||
self._ahead_buf.clear()
|
||||
|
||||
def switch_source_iterable(self, new_source: Iterable[T]) -> None:
|
||||
"""
|
||||
Switch the source of the backtrackable to a new iterable, keeping the history.
|
||||
|
||||
This is useful when iterating over a sequence of datasets. The history from the
|
||||
previous source is kept, but the lookahead buffer is cleared. The cursor is reset
|
||||
to the present.
|
||||
"""
|
||||
self._source = iter(new_source)
|
||||
self.clear_ahead_buffer()
|
||||
self.reset_cursor()
|
||||
|
||||
|
||||
def safe_shard(dataset: datasets.IterableDataset, index: int, num_shards: int) -> datasets.Dataset:
|
||||
"""
|
||||
Safe shards the dataset.
|
||||
"""
|
||||
shard_idx = min(dataset.num_shards, index + 1) - 1
|
||||
|
||||
return dataset.shard(num_shards, index=shard_idx)
|
||||
|
||||
@@ -21,9 +21,11 @@ import tempfile
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
from typing import Any, ClassVar
|
||||
|
||||
import av
|
||||
import fsspec
|
||||
import pyarrow as pa
|
||||
import torch
|
||||
import torchvision
|
||||
@@ -169,15 +171,68 @@ def decode_video_frames_torchvision(
|
||||
return closest_frames
|
||||
|
||||
|
||||
class VideoDecoderCache:
|
||||
"""Thread-safe cache for video decoders to avoid expensive re-initialization."""
|
||||
|
||||
def __init__(self):
|
||||
self._cache: dict[str, tuple[Any, Any]] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def get_decoder(self, video_path: str):
|
||||
"""Get a cached decoder or create a new one."""
|
||||
if importlib.util.find_spec("torchcodec"):
|
||||
from torchcodec.decoders import VideoDecoder
|
||||
else:
|
||||
raise ImportError("torchcodec is required but not available.")
|
||||
|
||||
video_path = str(video_path)
|
||||
|
||||
with self._lock:
|
||||
if video_path not in self._cache:
|
||||
file_handle = fsspec.open(video_path).__enter__()
|
||||
decoder = VideoDecoder(file_handle, seek_mode="approximate")
|
||||
self._cache[video_path] = (decoder, file_handle)
|
||||
|
||||
return self._cache[video_path][0]
|
||||
|
||||
def clear(self):
|
||||
"""Clear the cache and close file handles."""
|
||||
with self._lock:
|
||||
for _, file_handle in self._cache.values():
|
||||
file_handle.close()
|
||||
self._cache.clear()
|
||||
|
||||
def size(self) -> int:
|
||||
"""Return the number of cached decoders."""
|
||||
with self._lock:
|
||||
return len(self._cache)
|
||||
|
||||
|
||||
class FrameTimestampError(ValueError):
|
||||
"""Helper error to indicate the retrieved timestamps exceed the queried ones"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
_default_decoder_cache = VideoDecoderCache()
|
||||
|
||||
|
||||
def decode_video_frames_torchcodec(
|
||||
video_path: Path | str,
|
||||
timestamps: list[float],
|
||||
tolerance_s: float,
|
||||
device: str = "cpu",
|
||||
log_loaded_timestamps: bool = False,
|
||||
decoder_cache: VideoDecoderCache | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Loads frames associated with the requested timestamps of a video using torchcodec.
|
||||
|
||||
Args:
|
||||
video_path: Path to the video file.
|
||||
timestamps: List of timestamps to extract frames.
|
||||
tolerance_s: Allowed deviation in seconds for frame retrieval.
|
||||
log_loaded_timestamps: Whether to log loaded timestamps.
|
||||
decoder_cache: Optional decoder cache instance. Uses default if None.
|
||||
|
||||
Note: Setting device="cuda" outside the main process, e.g. in data loader workers, will lead to CUDA initialization errors.
|
||||
|
||||
Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
|
||||
@@ -186,27 +241,24 @@ def decode_video_frames_torchcodec(
|
||||
and all subsequent frames until reaching the requested frame. The number of key frames in a video
|
||||
can be adjusted during encoding to take into account decoding time and video size in bytes.
|
||||
"""
|
||||
if decoder_cache is None:
|
||||
decoder_cache = _default_decoder_cache
|
||||
|
||||
if importlib.util.find_spec("torchcodec"):
|
||||
from torchcodec.decoders import VideoDecoder
|
||||
else:
|
||||
raise ImportError("torchcodec is required but not available.")
|
||||
# Use cached decoder instead of creating new one each time
|
||||
decoder = decoder_cache.get_decoder(str(video_path))
|
||||
|
||||
# initialize video decoder
|
||||
decoder = VideoDecoder(video_path, device=device, seek_mode="approximate")
|
||||
loaded_frames = []
|
||||
loaded_ts = []
|
||||
loaded_frames = []
|
||||
|
||||
# get metadata for frame information
|
||||
metadata = decoder.metadata
|
||||
average_fps = metadata.average_fps
|
||||
|
||||
# convert timestamps to frame indices
|
||||
frame_indices = [round(ts * average_fps) for ts in timestamps]
|
||||
|
||||
# retrieve frames based on indices
|
||||
frames_batch = decoder.get_frames_at(indices=frame_indices)
|
||||
|
||||
for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=False):
|
||||
for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=True):
|
||||
loaded_frames.append(frame)
|
||||
loaded_ts.append(pts.item())
|
||||
if log_loaded_timestamps:
|
||||
@@ -237,10 +289,14 @@ def decode_video_frames_torchcodec(
|
||||
if log_loaded_timestamps:
|
||||
logging.info(f"{closest_ts=}")
|
||||
|
||||
# convert to float32 in [0,1] range (channel first)
|
||||
closest_frames = closest_frames.type(torch.float32) / 255
|
||||
# convert to float32 in [0,1] range
|
||||
closest_frames = (closest_frames / 255.0).type(torch.float32)
|
||||
|
||||
if not len(timestamps) == len(closest_frames):
|
||||
raise FrameTimestampError(
|
||||
f"Retrieved timestamps differ from queried {set(closest_frames) - set(timestamps)}"
|
||||
)
|
||||
|
||||
assert len(timestamps) == len(closest_frames)
|
||||
return closest_frames
|
||||
|
||||
|
||||
|
||||
@@ -179,10 +179,11 @@ def train(cfg: TrainPipelineConfig):
|
||||
dataset,
|
||||
num_workers=cfg.num_workers,
|
||||
batch_size=cfg.batch_size,
|
||||
shuffle=shuffle,
|
||||
shuffle=shuffle and not cfg.dataset.streaming,
|
||||
sampler=sampler,
|
||||
pin_memory=device.type == "cuda",
|
||||
drop_last=False,
|
||||
prefetch_factor=2,
|
||||
)
|
||||
dl_iter = cycle(dataloader)
|
||||
|
||||
@@ -208,6 +209,9 @@ def train(cfg: TrainPipelineConfig):
|
||||
|
||||
for key in batch:
|
||||
if isinstance(batch[key], torch.Tensor):
|
||||
if batch[key].dtype != torch.bool:
|
||||
batch[key] = batch[key].type(torch.float32) if device.type == "mps" else batch[key]
|
||||
|
||||
batch[key] = batch[key].to(device, non_blocking=device.type == "cuda")
|
||||
|
||||
train_tracker, output_dict = update_policy(
|
||||
|
||||
391
tests/datasets/test_streaming.py
Normal file
391
tests/datasets/test_streaming.py
Normal file
@@ -0,0 +1,391 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||
from lerobot.datasets.utils import safe_shard
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
|
||||
def get_frames_expected_order(streaming_ds: StreamingLeRobotDataset) -> list[int]:
|
||||
"""Replicates the shuffling logic of StreamingLeRobotDataset to get the expected order of indices."""
|
||||
rng = np.random.default_rng(streaming_ds.seed)
|
||||
buffer_size = streaming_ds.buffer_size
|
||||
num_shards = streaming_ds.num_shards
|
||||
|
||||
shards_indices = []
|
||||
for shard_idx in range(num_shards):
|
||||
shard = streaming_ds.hf_dataset.shard(num_shards, index=shard_idx)
|
||||
shard_indices = [item["index"] for item in shard]
|
||||
shards_indices.append(shard_indices)
|
||||
|
||||
shard_iterators = {i: iter(s) for i, s in enumerate(shards_indices)}
|
||||
|
||||
buffer_indices_generator = streaming_ds._iter_random_indices(rng, buffer_size)
|
||||
|
||||
frames_buffer = []
|
||||
expected_indices = []
|
||||
|
||||
while shard_iterators: # While there are still available shards
|
||||
available_shard_keys = list(shard_iterators.keys())
|
||||
if not available_shard_keys:
|
||||
break
|
||||
|
||||
# Call _infinite_generator_over_elements with current available shards (key difference!)
|
||||
shard_key = next(streaming_ds._infinite_generator_over_elements(rng, available_shard_keys))
|
||||
|
||||
try:
|
||||
frame_index = next(shard_iterators[shard_key])
|
||||
|
||||
if len(frames_buffer) == buffer_size:
|
||||
i = next(buffer_indices_generator)
|
||||
expected_indices.append(frames_buffer[i])
|
||||
frames_buffer[i] = frame_index
|
||||
else:
|
||||
frames_buffer.append(frame_index)
|
||||
|
||||
except StopIteration:
|
||||
del shard_iterators[shard_key] # Remove exhausted shard
|
||||
|
||||
rng.shuffle(frames_buffer)
|
||||
expected_indices.extend(frames_buffer)
|
||||
|
||||
return expected_indices
|
||||
|
||||
|
||||
def test_single_frame_consistency(tmp_path, lerobot_dataset_factory):
|
||||
"""Test if are correctly accessed"""
|
||||
ds_num_frames = 400
|
||||
ds_num_episodes = 10
|
||||
buffer_size = 100
|
||||
|
||||
local_path = tmp_path / "test"
|
||||
repo_id = f"{DUMMY_REPO_ID}"
|
||||
|
||||
ds = lerobot_dataset_factory(
|
||||
root=local_path,
|
||||
repo_id=repo_id,
|
||||
total_episodes=ds_num_episodes,
|
||||
total_frames=ds_num_frames,
|
||||
)
|
||||
|
||||
streaming_ds = iter(StreamingLeRobotDataset(repo_id=repo_id, root=local_path, buffer_size=buffer_size))
|
||||
|
||||
key_checks = []
|
||||
for _ in range(ds_num_frames):
|
||||
streaming_frame = next(streaming_ds)
|
||||
frame_idx = streaming_frame["index"]
|
||||
target_frame = ds[frame_idx]
|
||||
|
||||
for key in streaming_frame:
|
||||
left = streaming_frame[key]
|
||||
right = target_frame[key]
|
||||
|
||||
if isinstance(left, str):
|
||||
check = left == right
|
||||
|
||||
elif isinstance(left, torch.Tensor):
|
||||
check = torch.allclose(left, right) and left.shape == right.shape
|
||||
|
||||
elif isinstance(left, float):
|
||||
check = left == right.item() # right is a torch.Tensor
|
||||
|
||||
key_checks.append((key, check))
|
||||
|
||||
assert all(t[1] for t in key_checks), (
|
||||
f"Checking {list(filter(lambda t: not t[1], key_checks))[0][0]} left and right were found different (frame_idx: {frame_idx})"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"shuffle",
|
||||
[False, True],
|
||||
)
|
||||
def test_frames_order_over_epochs(tmp_path, lerobot_dataset_factory, shuffle):
|
||||
"""Test if streamed frames correspond to shuffling operations over in-memory dataset."""
|
||||
ds_num_frames = 400
|
||||
ds_num_episodes = 10
|
||||
buffer_size = 100
|
||||
seed = 42
|
||||
n_epochs = 3
|
||||
|
||||
local_path = tmp_path / "test"
|
||||
repo_id = f"{DUMMY_REPO_ID}"
|
||||
|
||||
lerobot_dataset_factory(
|
||||
root=local_path,
|
||||
repo_id=repo_id,
|
||||
total_episodes=ds_num_episodes,
|
||||
total_frames=ds_num_frames,
|
||||
)
|
||||
|
||||
streaming_ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id, root=local_path, buffer_size=buffer_size, seed=seed, shuffle=shuffle
|
||||
)
|
||||
|
||||
first_epoch_indices = [frame["index"] for frame in streaming_ds]
|
||||
expected_indices = get_frames_expected_order(streaming_ds)
|
||||
|
||||
assert first_epoch_indices == expected_indices, "First epoch indices do not match expected indices"
|
||||
|
||||
expected_indices = get_frames_expected_order(streaming_ds)
|
||||
for _ in range(n_epochs):
|
||||
streaming_indices = [frame["index"] for frame in streaming_ds]
|
||||
frames_match = all(
|
||||
s_index == e_index for s_index, e_index in zip(streaming_indices, expected_indices, strict=True)
|
||||
)
|
||||
|
||||
if shuffle:
|
||||
assert not frames_match
|
||||
else:
|
||||
assert frames_match
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"shuffle",
|
||||
[False, True],
|
||||
)
|
||||
def test_frames_order_with_shards(tmp_path, lerobot_dataset_factory, shuffle):
|
||||
"""Test if streamed frames correspond to shuffling operations over in-memory dataset with multiple shards."""
|
||||
ds_num_frames = 100
|
||||
ds_num_episodes = 10
|
||||
buffer_size = 10
|
||||
|
||||
seed = 42
|
||||
n_epochs = 3
|
||||
data_file_size_mb = 0.001
|
||||
|
||||
chunks_size = 1
|
||||
|
||||
local_path = tmp_path / "test"
|
||||
repo_id = f"{DUMMY_REPO_ID}-ciao"
|
||||
|
||||
lerobot_dataset_factory(
|
||||
root=local_path,
|
||||
repo_id=repo_id,
|
||||
total_episodes=ds_num_episodes,
|
||||
total_frames=ds_num_frames,
|
||||
data_files_size_in_mb=data_file_size_mb,
|
||||
chunks_size=chunks_size,
|
||||
)
|
||||
|
||||
streaming_ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=local_path,
|
||||
buffer_size=buffer_size,
|
||||
seed=seed,
|
||||
shuffle=shuffle,
|
||||
max_num_shards=4,
|
||||
)
|
||||
|
||||
first_epoch_indices = [frame["index"] for frame in streaming_ds]
|
||||
expected_indices = get_frames_expected_order(streaming_ds)
|
||||
|
||||
assert first_epoch_indices == expected_indices, "First epoch indices do not match expected indices"
|
||||
|
||||
for _ in range(n_epochs):
|
||||
streaming_indices = [
|
||||
frame["index"] for frame in streaming_ds
|
||||
] # NOTE: this is the same as first_epoch_indices
|
||||
frames_match = all(
|
||||
s_index == e_index for s_index, e_index in zip(streaming_indices, expected_indices, strict=True)
|
||||
)
|
||||
if shuffle:
|
||||
assert not frames_match
|
||||
else:
|
||||
assert frames_match
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"state_deltas, action_deltas",
|
||||
[
|
||||
([-1, -0.5, -0.20, 0], [0, 1, 2, 3]),
|
||||
([-1, -0.5, -0.20, 0], [-1.5, -1, -0.5, -0.20, -0.10, 0]),
|
||||
([-2, -1, -0.5, 0], [0, 1, 2, 3]),
|
||||
([-2, -1, -0.5, 0], [-1.5, -1, -0.5, -0.20, -0.10, 0]),
|
||||
],
|
||||
)
|
||||
def test_frames_with_delta_consistency(tmp_path, lerobot_dataset_factory, state_deltas, action_deltas):
|
||||
ds_num_frames = 500
|
||||
ds_num_episodes = 10
|
||||
buffer_size = 100
|
||||
|
||||
seed = 42
|
||||
|
||||
local_path = tmp_path / "test"
|
||||
repo_id = f"{DUMMY_REPO_ID}-ciao"
|
||||
camera_key = "phone"
|
||||
|
||||
delta_timestamps = {
|
||||
camera_key: state_deltas,
|
||||
"state": state_deltas,
|
||||
"action": action_deltas,
|
||||
}
|
||||
|
||||
ds = lerobot_dataset_factory(
|
||||
root=local_path,
|
||||
repo_id=repo_id,
|
||||
total_episodes=ds_num_episodes,
|
||||
total_frames=ds_num_frames,
|
||||
delta_timestamps=delta_timestamps,
|
||||
)
|
||||
|
||||
streaming_ds = iter(
|
||||
StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=local_path,
|
||||
buffer_size=buffer_size,
|
||||
seed=seed,
|
||||
shuffle=False,
|
||||
delta_timestamps=delta_timestamps,
|
||||
)
|
||||
)
|
||||
|
||||
for i in range(ds_num_frames):
|
||||
streaming_frame = next(streaming_ds)
|
||||
frame_idx = streaming_frame["index"]
|
||||
target_frame = ds[frame_idx]
|
||||
|
||||
assert set(streaming_frame.keys()) == set(target_frame.keys()), (
|
||||
f"Keys differ between streaming frame and target one. Differ at: {set(streaming_frame.keys()) - set(target_frame.keys())}"
|
||||
)
|
||||
|
||||
key_checks = []
|
||||
for key in streaming_frame:
|
||||
left = streaming_frame[key]
|
||||
right = target_frame[key]
|
||||
|
||||
if isinstance(left, str):
|
||||
check = left == right
|
||||
|
||||
elif isinstance(left, torch.Tensor):
|
||||
if (
|
||||
key not in ds.meta.camera_keys
|
||||
and "is_pad" not in key
|
||||
and f"{key}_is_pad" in streaming_frame
|
||||
):
|
||||
# comparing frames only on non-padded regions. Padding is applied to last-valid broadcasting
|
||||
left = left[~streaming_frame[f"{key}_is_pad"]]
|
||||
right = right[~target_frame[f"{key}_is_pad"]]
|
||||
|
||||
check = torch.allclose(left, right) and left.shape == right.shape
|
||||
|
||||
key_checks.append((key, check))
|
||||
|
||||
assert all(t[1] for t in key_checks), (
|
||||
f"Checking {list(filter(lambda t: not t[1], key_checks))[0][0]} left and right were found different (i: {i}, frame_idx: {frame_idx})"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"state_deltas, action_deltas",
|
||||
[
|
||||
([-1, -0.5, -0.20, 0], [0, 1, 2, 3, 10, 20]),
|
||||
([-1, -0.5, -0.20, 0], [-20, -1.5, -1, -0.5, -0.20, -0.10, 0]),
|
||||
([-2, -1, -0.5, 0], [0, 1, 2, 3, 10, 20]),
|
||||
([-2, -1, -0.5, 0], [-20, -1.5, -1, -0.5, -0.20, -0.10, 0]),
|
||||
],
|
||||
)
|
||||
def test_frames_with_delta_consistency_with_shards(
|
||||
tmp_path, lerobot_dataset_factory, state_deltas, action_deltas
|
||||
):
|
||||
ds_num_frames = 100
|
||||
ds_num_episodes = 10
|
||||
buffer_size = 10
|
||||
data_file_size_mb = 0.001
|
||||
chunks_size = 1
|
||||
|
||||
seed = 42
|
||||
|
||||
local_path = tmp_path / "test"
|
||||
repo_id = f"{DUMMY_REPO_ID}-ciao"
|
||||
camera_key = "phone"
|
||||
|
||||
delta_timestamps = {
|
||||
camera_key: state_deltas,
|
||||
"state": state_deltas,
|
||||
"action": action_deltas,
|
||||
}
|
||||
|
||||
ds = lerobot_dataset_factory(
|
||||
root=local_path,
|
||||
repo_id=repo_id,
|
||||
total_episodes=ds_num_episodes,
|
||||
total_frames=ds_num_frames,
|
||||
delta_timestamps=delta_timestamps,
|
||||
data_files_size_in_mb=data_file_size_mb,
|
||||
chunks_size=chunks_size,
|
||||
)
|
||||
streaming_ds = StreamingLeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=local_path,
|
||||
buffer_size=buffer_size,
|
||||
seed=seed,
|
||||
shuffle=False,
|
||||
delta_timestamps=delta_timestamps,
|
||||
max_num_shards=4,
|
||||
)
|
||||
|
||||
iter(streaming_ds)
|
||||
|
||||
num_shards = 4
|
||||
shards_indices = []
|
||||
for shard_idx in range(num_shards):
|
||||
shard = safe_shard(streaming_ds.hf_dataset, shard_idx, num_shards)
|
||||
shard_indices = [item["index"] for item in shard]
|
||||
shards_indices.append(shard_indices)
|
||||
|
||||
streaming_ds = iter(streaming_ds)
|
||||
|
||||
for i in range(ds_num_frames):
|
||||
streaming_frame = next(streaming_ds)
|
||||
frame_idx = streaming_frame["index"]
|
||||
target_frame = ds[frame_idx]
|
||||
|
||||
assert set(streaming_frame.keys()) == set(target_frame.keys()), (
|
||||
f"Keys differ between streaming frame and target one. Differ at: {set(streaming_frame.keys()) - set(target_frame.keys())}"
|
||||
)
|
||||
|
||||
key_checks = []
|
||||
for key in streaming_frame:
|
||||
left = streaming_frame[key]
|
||||
right = target_frame[key]
|
||||
|
||||
if isinstance(left, str):
|
||||
check = left == right
|
||||
|
||||
elif isinstance(left, torch.Tensor):
|
||||
if (
|
||||
key not in ds.meta.camera_keys
|
||||
and "is_pad" not in key
|
||||
and f"{key}_is_pad" in streaming_frame
|
||||
):
|
||||
# comparing frames only on non-padded regions. Padding is applied to last-valid broadcasting
|
||||
left = left[~streaming_frame[f"{key}_is_pad"]]
|
||||
right = right[~target_frame[f"{key}_is_pad"]]
|
||||
|
||||
check = torch.allclose(left, right) and left.shape == right.shape
|
||||
|
||||
elif isinstance(left, float):
|
||||
check = left == right.item() # right is a torch.Tensor
|
||||
|
||||
key_checks.append((key, check))
|
||||
|
||||
assert all(t[1] for t in key_checks), (
|
||||
f"Checking {list(filter(lambda t: not t[1], key_checks))[0][0]} left and right were found different (i: {i}, frame_idx: {frame_idx})"
|
||||
)
|
||||
Reference in New Issue
Block a user