Add Streaming Dataset (#1613)
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
This commit is contained in:
committed by
GitHub
parent
f55c6e89f0
commit
33cad37054
116
examples/5_train_with_streaming.py
Normal file
116
examples/5_train_with_streaming.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""This script demonstrates how to train a Diffusion Policy on the PushT environment,
|
||||||
|
using a dataset processed in streaming mode.
|
||||||
|
|
||||||
|
Once you have trained a model with this script, you can try to evaluate it on
|
||||||
|
examples/2_evaluate_pretrained_policy.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from lerobot.configs.types import FeatureType
|
||||||
|
from lerobot.constants import ACTION
|
||||||
|
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||||
|
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||||
|
from lerobot.datasets.utils import dataset_to_policy_features
|
||||||
|
from lerobot.policies.act.configuration_act import ACTConfig
|
||||||
|
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Train an ACT policy on a large dataset streamed from the Hugging Face Hub.

    Streaming avoids downloading the (potentially multi-TB) dataset locally:
    frames are fetched iteratively and shuffled on the fly, then fed to a
    standard PyTorch DataLoader/optimizer loop. A checkpoint is saved at the end.
    """
    # Create a directory to store the training checkpoint.
    output_directory = Path("outputs/train/example_streaming_dataset")
    output_directory.mkdir(parents=True, exist_ok=True)

    # Selects the "best" device available
    device = (
        torch.device("cuda")
        if torch.cuda.is_available()
        else torch.device("mps")
        if torch.backends.mps.is_available()
        else torch.device("cpu")
    )
    print(f"Using device: {device}")

    training_steps = 10
    log_freq = 1

    dataset_id = (
        "aractingi/droid_1.0.1"  # 26M frames! Would require 4TB of disk space if installed locally (:
    )
    dataset_metadata = LeRobotDatasetMetadata(dataset_id)
    features = dataset_to_policy_features(dataset_metadata.features)
    output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
    input_features = {key: ft for key, ft in features.items() if key not in output_features}

    # We can now instantiate our policy with this config and the dataset stats.
    cfg = ACTConfig(input_features=input_features, output_features=output_features)
    policy = ACTPolicy(cfg, dataset_stats=dataset_metadata.stats)
    policy.train()
    policy.to(device)

    # Delta timestamps are used to (1) augment frames used during training and (2) supervise the policy.
    # Here, we use delta-timestamps to only provide ground truth actions for supervision.
    delta_timestamps = {
        ACTION: [t / dataset_metadata.fps for t in range(cfg.n_action_steps)],
    }

    # Instantiating the training dataset in streaming mode avoids loading all data into memory: frames
    # are fetched iteratively rather than loaded all at once. Retrieved frames are shuffled across epochs.
    dataset = StreamingLeRobotDataset(dataset_id, delta_timestamps=delta_timestamps, tolerance_s=1e-3)

    optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=4,
        batch_size=16,
        pin_memory=device.type != "cpu",
        drop_last=True,
        prefetch_factor=2,  # loads batches with multiprocessing while policy trains
    )

    # Run training loop.
    step = 0
    done = False
    while not done:
        for batch in dataloader:
            # Cast non-boolean tensors to float32 and move every tensor to the training device
            # in a single pass over the batch (non-tensor entries pass through untouched).
            batch = {
                k: (
                    (v.type(torch.float32) if v.dtype != torch.bool else v).to(device)
                    if isinstance(v, torch.Tensor)
                    else v
                )
                for k, v in batch.items()
            }
            loss, _ = policy.forward(batch)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if step % log_freq == 0:
                print(f"step: {step} loss: {loss.item():.3f}")
            step += 1
            if step >= training_steps:
                done = True
                break

    # Save a policy checkpoint.
    policy.save_pretrained(output_directory)
||||||
@@ -37,6 +37,7 @@ class DatasetConfig:
|
|||||||
revision: str | None = None
|
revision: str | None = None
|
||||||
use_imagenet_stats: bool = True
|
use_imagenet_stats: bool = True
|
||||||
video_backend: str = field(default_factory=get_safe_default_codec)
|
video_backend: str = field(default_factory=get_safe_default_codec)
|
||||||
|
streaming: bool = False
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -52,3 +52,8 @@ HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expandu
|
|||||||
# calibration dir
|
# calibration dir
|
||||||
default_calibration_path = HF_LEROBOT_HOME / "calibration"
|
default_calibration_path = HF_LEROBOT_HOME / "calibration"
|
||||||
HF_LEROBOT_CALIBRATION = Path(os.getenv("HF_LEROBOT_CALIBRATION", default_calibration_path)).expanduser()
|
HF_LEROBOT_CALIBRATION = Path(os.getenv("HF_LEROBOT_CALIBRATION", default_calibration_path)).expanduser()
|
||||||
|
|
||||||
|
|
||||||
|
# streaming datasets
|
||||||
|
LOOKBACK_BACKTRACKTABLE = 100
|
||||||
|
LOOKAHEAD_BACKTRACKTABLE = 100
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ from lerobot.datasets.lerobot_dataset import (
|
|||||||
LeRobotDatasetMetadata,
|
LeRobotDatasetMetadata,
|
||||||
MultiLeRobotDataset,
|
MultiLeRobotDataset,
|
||||||
)
|
)
|
||||||
|
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||||
from lerobot.datasets.transforms import ImageTransforms
|
from lerobot.datasets.transforms import ImageTransforms
|
||||||
|
|
||||||
IMAGENET_STATS = {
|
IMAGENET_STATS = {
|
||||||
@@ -87,15 +88,26 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
|||||||
cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
|
cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
|
||||||
)
|
)
|
||||||
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
|
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
|
||||||
dataset = LeRobotDataset(
|
if not cfg.dataset.streaming:
|
||||||
cfg.dataset.repo_id,
|
dataset = LeRobotDataset(
|
||||||
root=cfg.dataset.root,
|
cfg.dataset.repo_id,
|
||||||
episodes=cfg.dataset.episodes,
|
root=cfg.dataset.root,
|
||||||
delta_timestamps=delta_timestamps,
|
episodes=cfg.dataset.episodes,
|
||||||
image_transforms=image_transforms,
|
delta_timestamps=delta_timestamps,
|
||||||
revision=cfg.dataset.revision,
|
image_transforms=image_transforms,
|
||||||
video_backend=cfg.dataset.video_backend,
|
revision=cfg.dataset.revision,
|
||||||
)
|
video_backend=cfg.dataset.video_backend,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
dataset = StreamingLeRobotDataset(
|
||||||
|
cfg.dataset.repo_id,
|
||||||
|
root=cfg.dataset.root,
|
||||||
|
episodes=cfg.dataset.episodes,
|
||||||
|
delta_timestamps=delta_timestamps,
|
||||||
|
image_transforms=image_transforms,
|
||||||
|
revision=cfg.dataset.revision,
|
||||||
|
max_num_shards=cfg.num_workers,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
|
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
|
||||||
dataset = MultiLeRobotDataset(
|
dataset = MultiLeRobotDataset(
|
||||||
|
|||||||
@@ -129,6 +129,10 @@ class LeRobotDatasetMetadata:
|
|||||||
ignore_patterns=ignore_patterns,
|
ignore_patterns=ignore_patterns,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
def url_root(self) -> str:
    """Remote root of this dataset on the Hugging Face Hub, e.g. ``hf://datasets/<repo_id>``."""
    return f"hf://datasets/{self.repo_id}"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _version(self) -> packaging.version.Version:
|
def _version(self) -> packaging.version.Version:
|
||||||
"""Codebase version used to create this dataset."""
|
"""Codebase version used to create this dataset."""
|
||||||
|
|||||||
535
src/lerobot/datasets/streaming_dataset.py
Normal file
535
src/lerobot/datasets/streaming_dataset.py
Normal file
@@ -0,0 +1,535 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from collections.abc import Callable, Generator, Iterator
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import datasets
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
|
from lerobot.constants import HF_LEROBOT_HOME, LOOKAHEAD_BACKTRACKTABLE, LOOKBACK_BACKTRACKTABLE
|
||||||
|
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
|
||||||
|
from lerobot.datasets.utils import (
|
||||||
|
Backtrackable,
|
||||||
|
LookAheadError,
|
||||||
|
LookBackError,
|
||||||
|
check_version_compatibility,
|
||||||
|
find_float_index,
|
||||||
|
get_delta_indices,
|
||||||
|
is_float_in_list,
|
||||||
|
item_to_torch,
|
||||||
|
safe_shard,
|
||||||
|
)
|
||||||
|
from lerobot.datasets.video_utils import (
|
||||||
|
VideoDecoderCache,
|
||||||
|
decode_video_frames_torchcodec,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||||
|
"""LeRobotDataset with streaming capabilities.
|
||||||
|
|
||||||
|
This class extends LeRobotDataset to add streaming functionality, allowing data to be streamed
|
||||||
|
rather than loaded entirely into memory. This is especially useful for large datasets that may
|
||||||
|
not fit in memory or when you want to quickly explore a dataset without downloading it completely.
|
||||||
|
|
||||||
|
The key innovation is using a Backtrackable iterator that maintains a bounded buffer of recent
|
||||||
|
items, allowing us to access previous frames for delta timestamps without loading the entire
|
||||||
|
dataset into memory.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
Basic usage:
|
||||||
|
```python
|
||||||
|
from lerobot.common.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||||
|
|
||||||
|
# Create a streaming dataset with delta timestamps
|
||||||
|
delta_timestamps = {
|
||||||
|
"observation.image": [-1.0, -0.5, 0.0], # 1 sec ago, 0.5 sec ago, current
|
||||||
|
"action": [0.0, 0.1, 0.2], # current, 0.1 sec future, 0.2 sec future
|
||||||
|
}
|
||||||
|
|
||||||
|
dataset = StreamingLeRobotDataset(
|
||||||
|
repo_id="your-dataset-repo-id",
|
||||||
|
delta_timestamps=delta_timestamps,
|
||||||
|
streaming=True,
|
||||||
|
buffer_size=1000,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Iterate over the dataset
|
||||||
|
for i, item in enumerate(dataset):
|
||||||
|
print(f"Sample {i}: Episode {item['episode_index']} Frame {item['frame_index']}")
|
||||||
|
# item will contain stacked frames according to delta_timestamps
|
||||||
|
if i >= 10:
|
||||||
|
break
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
repo_id: str,
|
||||||
|
root: str | Path | None = None,
|
||||||
|
episodes: list[int] | None = None,
|
||||||
|
image_transforms: Callable | None = None,
|
||||||
|
delta_timestamps: dict[list[float]] | None = None,
|
||||||
|
tolerance_s: float = 1e-4,
|
||||||
|
revision: str | None = None,
|
||||||
|
force_cache_sync: bool = False,
|
||||||
|
streaming: bool = True,
|
||||||
|
buffer_size: int = 1000,
|
||||||
|
max_num_shards: int = 16,
|
||||||
|
seed: int = 42,
|
||||||
|
rng: np.random.Generator | None = None,
|
||||||
|
shuffle: bool = True,
|
||||||
|
):
|
||||||
|
"""Initialize a StreamingLeRobotDataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repo_id (str): This is the repo id that will be used to fetch the dataset.
|
||||||
|
root (Path | None, optional): Local directory to use for downloading/writing files.
|
||||||
|
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
|
||||||
|
their episode_index in this list.
|
||||||
|
image_transforms (Callable | None, optional): Transform to apply to image data.
|
||||||
|
tolerance_s (float, optional): Tolerance in seconds for timestamp matching.
|
||||||
|
revision (str, optional): Git revision id (branch name, tag, or commit hash).
|
||||||
|
force_cache_sync (bool, optional): Flag to sync and refresh local files first.
|
||||||
|
streaming (bool, optional): Whether to stream the dataset or load it all. Defaults to True.
|
||||||
|
buffer_size (int, optional): Buffer size for shuffling when streaming. Defaults to 1000.
|
||||||
|
max_num_shards (int, optional): Number of shards to re-shard the input dataset into. Defaults to 16.
|
||||||
|
seed (int, optional): Reproducibility random seed.
|
||||||
|
rng (np.random.Generator | None, optional): Random number generator.
|
||||||
|
shuffle (bool, optional): Whether to shuffle the dataset across exhaustions. Defaults to True.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.repo_id = repo_id
|
||||||
|
self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
|
||||||
|
self.streaming_from_local = root is not None
|
||||||
|
|
||||||
|
self.image_transforms = image_transforms
|
||||||
|
self.episodes = episodes
|
||||||
|
self.tolerance_s = tolerance_s
|
||||||
|
self.revision = revision if revision else CODEBASE_VERSION
|
||||||
|
self.seed = seed
|
||||||
|
self.rng = rng if rng is not None else np.random.default_rng(seed)
|
||||||
|
self.shuffle = shuffle
|
||||||
|
|
||||||
|
self.streaming = streaming
|
||||||
|
self.buffer_size = buffer_size
|
||||||
|
|
||||||
|
# We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown)
|
||||||
|
self.video_decoder_cache = None
|
||||||
|
|
||||||
|
self.root.mkdir(exist_ok=True, parents=True)
|
||||||
|
|
||||||
|
# Load metadata
|
||||||
|
self.meta = LeRobotDatasetMetadata(
|
||||||
|
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
|
||||||
|
)
|
||||||
|
# Check version
|
||||||
|
check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
|
||||||
|
|
||||||
|
self.delta_timestamps = None
|
||||||
|
self.delta_indices = None
|
||||||
|
|
||||||
|
if delta_timestamps is not None:
|
||||||
|
self._validate_delta_timestamp_keys(delta_timestamps) # raises ValueError if invalid
|
||||||
|
self.delta_timestamps = delta_timestamps
|
||||||
|
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
|
||||||
|
|
||||||
|
self.hf_dataset: datasets.IterableDataset = load_dataset(
|
||||||
|
self.repo_id if not self.streaming_from_local else str(self.root),
|
||||||
|
split="train",
|
||||||
|
streaming=self.streaming,
|
||||||
|
data_files="data/*/*.parquet",
|
||||||
|
revision=self.revision,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.num_shards = min(self.hf_dataset.num_shards, max_num_shards)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_frames(self):
|
||||||
|
return self.meta.total_frames
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_episodes(self):
|
||||||
|
return self.meta.total_episodes
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fps(self):
|
||||||
|
return self.meta.fps
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _iter_random_indices(
|
||||||
|
rng: np.random.Generator, buffer_size: int, random_batch_size=100
|
||||||
|
) -> Iterator[int]:
|
||||||
|
while True:
|
||||||
|
yield from (int(i) for i in rng.integers(0, buffer_size, size=random_batch_size))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _infinite_generator_over_elements(rng: np.random.Generator, elements: list[int]) -> Iterator[int]:
|
||||||
|
while True:
|
||||||
|
yield rng.choice(elements)
|
||||||
|
|
||||||
|
# TODO(fracapuano): Implement multi-threaded prefetching to accelerate data loading.
|
||||||
|
# The current sequential iteration is a bottleneck. A producer-consumer pattern
|
||||||
|
# could be used with a ThreadPoolExecutor to run `make_frame` (especially video decoding)
|
||||||
|
# in parallel, feeding a queue from which this iterator will yield processed items.
|
||||||
|
def __iter__(self) -> Iterator[dict[str, torch.Tensor]]:
|
||||||
|
if self.video_decoder_cache is None:
|
||||||
|
self.video_decoder_cache = VideoDecoderCache()
|
||||||
|
|
||||||
|
# keep the same seed across exhaustions if shuffle is False, otherwise shuffle data across exhaustions
|
||||||
|
rng = np.random.default_rng(self.seed) if not self.shuffle else self.rng
|
||||||
|
|
||||||
|
buffer_indices_generator = self._iter_random_indices(rng, self.buffer_size)
|
||||||
|
|
||||||
|
idx_to_backtrack_dataset = {
|
||||||
|
idx: self._make_backtrackable_dataset(safe_shard(self.hf_dataset, idx, self.num_shards))
|
||||||
|
for idx in range(self.num_shards)
|
||||||
|
}
|
||||||
|
|
||||||
|
# This buffer is populated while iterating on the dataset's shards
|
||||||
|
# the logic is to add 2 levels of randomness:
|
||||||
|
# (1) sample one shard at random from the ones available, and
|
||||||
|
# (2) sample one frame from the shard sampled at (1)
|
||||||
|
frames_buffer = []
|
||||||
|
while available_shards := list(idx_to_backtrack_dataset.keys()):
|
||||||
|
shard_key = next(self._infinite_generator_over_elements(rng, available_shards))
|
||||||
|
backtrack_dataset = idx_to_backtrack_dataset[shard_key] # selects which shard to iterate on
|
||||||
|
|
||||||
|
try:
|
||||||
|
for frame in self.make_frame(backtrack_dataset):
|
||||||
|
if len(frames_buffer) == self.buffer_size:
|
||||||
|
i = next(buffer_indices_generator) # samples a element from the buffer
|
||||||
|
yield frames_buffer[i]
|
||||||
|
frames_buffer[i] = frame
|
||||||
|
else:
|
||||||
|
frames_buffer.append(frame)
|
||||||
|
break # random shard sampled, switch shard
|
||||||
|
except (
|
||||||
|
RuntimeError,
|
||||||
|
StopIteration,
|
||||||
|
): # NOTE: StopIteration inside a generator throws a RuntimeError since python 3.7
|
||||||
|
del idx_to_backtrack_dataset[shard_key] # Remove exhausted shard, onto another shard
|
||||||
|
|
||||||
|
# Once shards are all exhausted, shuffle the buffer and yield the remaining frames
|
||||||
|
rng.shuffle(frames_buffer)
|
||||||
|
yield from frames_buffer
|
||||||
|
|
||||||
|
def _get_window_steps(
|
||||||
|
self, delta_timestamps: dict[str, list[float]] | None = None, dynamic_bounds: bool = False
|
||||||
|
) -> tuple[int, int]:
|
||||||
|
if delta_timestamps is None:
|
||||||
|
return 1, 1
|
||||||
|
|
||||||
|
if not dynamic_bounds:
|
||||||
|
# Fix the windows
|
||||||
|
lookback = LOOKBACK_BACKTRACKTABLE
|
||||||
|
lookahead = LOOKAHEAD_BACKTRACKTABLE
|
||||||
|
else:
|
||||||
|
# Dynamically adjust the windows based on the given delta_timesteps
|
||||||
|
all_timestamps = sum(delta_timestamps.values(), [])
|
||||||
|
lookback = min(all_timestamps) * self.fps
|
||||||
|
lookahead = max(all_timestamps) * self.fps
|
||||||
|
|
||||||
|
# When lookback is >=0 it means no negative timesteps have been provided
|
||||||
|
lookback = 0 if lookback >= 0 else (lookback * -1)
|
||||||
|
|
||||||
|
return lookback, lookahead
|
||||||
|
|
||||||
|
def _make_backtrackable_dataset(self, dataset: datasets.IterableDataset) -> Backtrackable:
|
||||||
|
lookback, lookahead = self._get_window_steps(self.delta_timestamps)
|
||||||
|
return Backtrackable(dataset, history=lookback, lookahead=lookahead)
|
||||||
|
|
||||||
|
def _make_timestamps_from_indices(
|
||||||
|
self, start_ts: float, indices: dict[str, list[int]] | None = None
|
||||||
|
) -> dict[str, list[float]]:
|
||||||
|
if indices is not None:
|
||||||
|
return {
|
||||||
|
key: (
|
||||||
|
start_ts + torch.tensor(indices[key]) / self.fps
|
||||||
|
).tolist() # NOTE: why not delta_timestamps directly?
|
||||||
|
for key in self.delta_timestamps
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return dict.fromkeys(self.meta.video_keys, [start_ts])
|
||||||
|
|
||||||
|
def _make_padding_camera_frame(self, camera_key: str):
|
||||||
|
"""Variable-shape padding frame for given camera keys, given in (H, W, C)"""
|
||||||
|
return torch.zeros(self.meta.info["features"][camera_key]["shape"]).permute(-1, 0, 1)
|
||||||
|
|
||||||
|
def _get_video_frame_padding_mask(
|
||||||
|
self,
|
||||||
|
video_frames: dict[str, torch.Tensor],
|
||||||
|
query_timestamps: dict[str, list[float]],
|
||||||
|
original_timestamps: dict[str, list[float]],
|
||||||
|
) -> dict[str, torch.BoolTensor]:
|
||||||
|
padding_mask = {}
|
||||||
|
|
||||||
|
for video_key, timestamps in original_timestamps.items():
|
||||||
|
if video_key not in video_frames:
|
||||||
|
continue # only padding on video keys that are available
|
||||||
|
frames = []
|
||||||
|
mask = []
|
||||||
|
padding_frame = self._make_padding_camera_frame(video_key)
|
||||||
|
for ts in timestamps:
|
||||||
|
if is_float_in_list(ts, query_timestamps[video_key]):
|
||||||
|
idx = find_float_index(ts, query_timestamps[video_key])
|
||||||
|
frames.append(video_frames[video_key][idx, :])
|
||||||
|
mask.append(False)
|
||||||
|
else:
|
||||||
|
frames.append(padding_frame)
|
||||||
|
mask.append(True)
|
||||||
|
|
||||||
|
padding_mask[f"{video_key}_is_pad"] = torch.BoolTensor(mask)
|
||||||
|
|
||||||
|
return padding_mask
|
||||||
|
|
||||||
|
def make_frame(
|
||||||
|
self, dataset_iterator: Backtrackable, previous_dataset_iterator: Backtrackable | None = None
|
||||||
|
) -> Generator:
|
||||||
|
"""Makes a frame starting from a dataset iterator"""
|
||||||
|
item = next(dataset_iterator)
|
||||||
|
item = item_to_torch(item)
|
||||||
|
|
||||||
|
updates = [] # list of "updates" to apply to the item retrieved from hf_dataset (w/o camera features)
|
||||||
|
|
||||||
|
# Get episode index from the item
|
||||||
|
ep_idx = item["episode_index"]
|
||||||
|
|
||||||
|
# "timestamp" restarts from 0 for each episode, whereas we need a global timestep within the single .mp4 file (given by index/fps)
|
||||||
|
current_ts = item["index"] / self.fps
|
||||||
|
|
||||||
|
episode_boundaries_ts = {
|
||||||
|
key: (
|
||||||
|
self.meta.episodes[ep_idx][f"videos/{key}/from_timestamp"],
|
||||||
|
self.meta.episodes[ep_idx][f"videos/{key}/to_timestamp"],
|
||||||
|
)
|
||||||
|
for key in self.meta.video_keys
|
||||||
|
}
|
||||||
|
|
||||||
|
# Apply delta querying logic if necessary
|
||||||
|
if self.delta_indices is not None:
|
||||||
|
query_result, padding = self._get_delta_frames(dataset_iterator, item)
|
||||||
|
updates.append(query_result)
|
||||||
|
updates.append(padding)
|
||||||
|
|
||||||
|
# Load video frames, when needed
|
||||||
|
if len(self.meta.video_keys) > 0:
|
||||||
|
original_timestamps = self._make_timestamps_from_indices(current_ts, self.delta_indices)
|
||||||
|
|
||||||
|
# Some timestamps might not result available considering the episode's boundaries
|
||||||
|
query_timestamps = self._get_query_timestamps(
|
||||||
|
current_ts, self.delta_indices, episode_boundaries_ts
|
||||||
|
)
|
||||||
|
video_frames = self._query_videos(query_timestamps, ep_idx)
|
||||||
|
|
||||||
|
if self.image_transforms is not None:
|
||||||
|
image_keys = self.meta.camera_keys
|
||||||
|
for cam in image_keys:
|
||||||
|
video_frames[cam] = self.image_transforms(video_frames[cam])
|
||||||
|
|
||||||
|
updates.append(video_frames)
|
||||||
|
|
||||||
|
if self.delta_indices is not None:
|
||||||
|
# We always return the same number of frames. Unavailable frames are padded.
|
||||||
|
padding_mask = self._get_video_frame_padding_mask(
|
||||||
|
video_frames, query_timestamps, original_timestamps
|
||||||
|
)
|
||||||
|
updates.append(padding_mask)
|
||||||
|
|
||||||
|
result = item.copy()
|
||||||
|
for update in updates:
|
||||||
|
result.update(update)
|
||||||
|
|
||||||
|
result["task"] = self.meta.tasks.iloc[item["task_index"]].name
|
||||||
|
|
||||||
|
yield result
|
||||||
|
|
||||||
|
def _get_query_timestamps(
|
||||||
|
self,
|
||||||
|
current_ts: float,
|
||||||
|
query_indices: dict[str, list[int]] | None = None,
|
||||||
|
episode_boundaries_ts: dict[str, tuple[float, float]] | None = None,
|
||||||
|
) -> dict[str, list[float]]:
|
||||||
|
query_timestamps = {}
|
||||||
|
keys_to_timestamps = self._make_timestamps_from_indices(current_ts, query_indices)
|
||||||
|
for key in self.meta.video_keys:
|
||||||
|
if query_indices is not None and key in query_indices:
|
||||||
|
timestamps = keys_to_timestamps[key]
|
||||||
|
# Clamp out timesteps outside of episode boundaries
|
||||||
|
query_timestamps[key] = torch.clamp(
|
||||||
|
torch.tensor(timestamps), *episode_boundaries_ts[key]
|
||||||
|
).tolist()
|
||||||
|
|
||||||
|
else:
|
||||||
|
query_timestamps[key] = [current_ts]
|
||||||
|
|
||||||
|
return query_timestamps
|
||||||
|
|
||||||
|
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict:
|
||||||
|
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
|
||||||
|
in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a
|
||||||
|
Segmentation Fault. This probably happens because a memory reference to the video loader is created in
|
||||||
|
the main process and a subprocess fails to access it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
item = {}
|
||||||
|
for video_key, query_ts in query_timestamps.items():
|
||||||
|
root = self.meta.url_root if self.streaming and not self.streaming_from_local else self.root
|
||||||
|
video_path = f"{root}/{self.meta.get_video_file_path(ep_idx, video_key)}"
|
||||||
|
frames = decode_video_frames_torchcodec(
|
||||||
|
video_path, query_ts, self.tolerance_s, decoder_cache=self.video_decoder_cache
|
||||||
|
)
|
||||||
|
|
||||||
|
item[video_key] = frames.squeeze(0) if len(query_ts) == 1 else frames
|
||||||
|
|
||||||
|
return item
|
||||||
|
|
||||||
|
def _get_delta_frames(self, dataset_iterator: Backtrackable, current_item: dict):
|
||||||
|
# TODO(fracapuano): Modularize this function, refactor the code
|
||||||
|
"""Get frames with delta offsets using the backtrackable iterator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
current_item (dict): Current item from the iterator.
|
||||||
|
ep_idx (int): Episode index.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (query_result, padding) - frames at delta offsets and padding info.
|
||||||
|
"""
|
||||||
|
current_episode_idx = current_item["episode_index"]
|
||||||
|
|
||||||
|
# Prepare results
|
||||||
|
query_result = {}
|
||||||
|
padding = {}
|
||||||
|
|
||||||
|
for key, delta_indices in self.delta_indices.items():
|
||||||
|
if key in self.meta.video_keys:
|
||||||
|
continue # visual frames are decoded separately
|
||||||
|
|
||||||
|
target_frames = []
|
||||||
|
is_pad = []
|
||||||
|
|
||||||
|
# Create a results dictionary to store frames in processing order, then reconstruct original order for stacking
|
||||||
|
delta_results = {}
|
||||||
|
|
||||||
|
# Separate and sort deltas by difficulty (easier operations first)
|
||||||
|
negative_deltas = sorted([d for d in delta_indices if d < 0], reverse=True) # [-1, -2, -3, ...]
|
||||||
|
positive_deltas = sorted([d for d in delta_indices if d > 0]) # [1, 2, 3, ...]
|
||||||
|
zero_deltas = [d for d in delta_indices if d == 0]
|
||||||
|
|
||||||
|
# Process zero deltas (current frame)
|
||||||
|
for delta in zero_deltas:
|
||||||
|
delta_results[delta] = (
|
||||||
|
current_item[key],
|
||||||
|
False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process negative deltas in order of increasing difficulty
|
||||||
|
lookback_failed = False
|
||||||
|
|
||||||
|
last_successful_frame = current_item[key]
|
||||||
|
|
||||||
|
for delta in negative_deltas:
|
||||||
|
if lookback_failed:
|
||||||
|
delta_results[delta] = (last_successful_frame, True)
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
steps_back = abs(delta)
|
||||||
|
if dataset_iterator.can_peek_back(steps_back):
|
||||||
|
past_item = dataset_iterator.peek_back(steps_back)
|
||||||
|
past_item = item_to_torch(past_item)
|
||||||
|
|
||||||
|
if past_item["episode_index"] == current_episode_idx:
|
||||||
|
delta_results[delta] = (past_item[key], False)
|
||||||
|
last_successful_frame = past_item[key]
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise LookBackError("Retrieved frame is from different episode!")
|
||||||
|
else:
|
||||||
|
raise LookBackError("Cannot go back further than the history buffer!")
|
||||||
|
|
||||||
|
except LookBackError:
|
||||||
|
delta_results[delta] = (last_successful_frame, True)
|
||||||
|
lookback_failed = True # All subsequent negative deltas will also fail
|
||||||
|
|
||||||
|
# Process positive deltas in order of increasing difficulty
|
||||||
|
lookahead_failed = False
|
||||||
|
last_successful_frame = current_item[key]
|
||||||
|
|
||||||
|
for delta in positive_deltas:
|
||||||
|
if lookahead_failed:
|
||||||
|
delta_results[delta] = (last_successful_frame, True)
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
if dataset_iterator.can_peek_ahead(delta):
|
||||||
|
future_item = dataset_iterator.peek_ahead(delta)
|
||||||
|
future_item = item_to_torch(future_item)
|
||||||
|
|
||||||
|
if future_item["episode_index"] == current_episode_idx:
|
||||||
|
delta_results[delta] = (future_item[key], False)
|
||||||
|
last_successful_frame = future_item[key]
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise LookAheadError("Retrieved frame is from different episode!")
|
||||||
|
else:
|
||||||
|
raise LookAheadError("Cannot go ahead further than the lookahead buffer!")
|
||||||
|
|
||||||
|
except LookAheadError:
|
||||||
|
delta_results[delta] = (last_successful_frame, True)
|
||||||
|
lookahead_failed = True # All subsequent positive deltas will also fail
|
||||||
|
|
||||||
|
# Reconstruct original order for stacking
|
||||||
|
for delta in delta_indices:
|
||||||
|
frame, is_padded = delta_results[delta]
|
||||||
|
|
||||||
|
# add batch dimension for stacking
|
||||||
|
target_frames.append(frame) # frame.unsqueeze(0))
|
||||||
|
is_pad.append(is_padded)
|
||||||
|
|
||||||
|
# Stack frames and add to results
|
||||||
|
if target_frames:
|
||||||
|
query_result[key] = torch.stack(target_frames)
|
||||||
|
padding[f"{key}_is_pad"] = torch.BoolTensor(is_pad)
|
||||||
|
|
||||||
|
return query_result, padding
|
||||||
|
|
||||||
|
def _validate_delta_timestamp_keys(self, delta_timestamps: dict[str, list[float]] | None) -> None:
    """Validate that all keys in delta_timestamps correspond to actual features in the dataset.

    Args:
        delta_timestamps: Mapping from feature key to a list of time offsets (in seconds).
            ``None`` is accepted and skips validation entirely.

    Raises:
        ValueError: If any delta timestamp key doesn't correspond to a dataset feature.
    """
    # NOTE: the annotation was `dict[list[float]]`, which is not a valid
    # two-parameter mapping type; it is `dict[str, list[float]] | None` here.
    if delta_timestamps is None:
        return

    # A requested key that is not a dataset feature is a configuration error:
    # fail fast with an explicit message listing what is actually available.
    available_features = set(self.meta.features.keys())
    invalid_keys = set(delta_timestamps.keys()) - available_features

    if invalid_keys:
        raise ValueError(
            f"The following delta_timestamp keys do not correspond to dataset features: {invalid_keys}. "
            f"Available features are: {sorted(available_features)}"
        )
|
||||||
@@ -17,10 +17,11 @@ import contextlib
|
|||||||
import importlib.resources
|
import importlib.resources
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Iterator
|
from collections import deque
|
||||||
|
from collections.abc import Iterable, Iterator
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
from typing import Any
|
from typing import Any, Deque, Generic, TypeVar
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -86,6 +87,8 @@ DEFAULT_FEATURES = {
|
|||||||
"task_index": {"dtype": "int64", "shape": (1,), "names": None},
|
"task_index": {"dtype": "int64", "shape": (1,), "names": None},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
def get_parquet_file_size_in_mb(parquet_path: str | Path) -> float:
|
def get_parquet_file_size_in_mb(parquet_path: str | Path) -> float:
|
||||||
metadata = pq.read_metadata(parquet_path)
|
metadata = pq.read_metadata(parquet_path)
|
||||||
@@ -776,3 +779,230 @@ def to_parquet_with_hf_images(df: pandas.DataFrame, path: Path) -> None:
|
|||||||
"""
|
"""
|
||||||
# TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
|
# TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
|
||||||
datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
|
datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
|
||||||
|
|
||||||
|
|
||||||
|
def item_to_torch(item: dict) -> dict:
    """Convert array-like values of a dataset item to PyTorch tensors, in place.

    This function is used to convert an item from a streaming dataset to PyTorch tensors.

    Args:
        item (dict): Dictionary of items from a dataset.

    Returns:
        dict: Dictionary with all tensor-like items converted to torch.Tensor.
    """
    for field_name, value in item.items():
        # "task" holds text and must stay untouched; everything else that is a
        # numpy array or a plain list becomes a tensor.
        if field_name in ["task"]:
            continue
        if isinstance(value, (np.ndarray, list)):
            item[field_name] = torch.tensor(value)
    return item
|
||||||
|
|
||||||
|
|
||||||
|
def is_float_in_list(target, float_list, threshold=1e-6):
    """Return True if `target` is within `threshold` of any element of `float_list`."""
    for candidate in float_list:
        if abs(target - candidate) <= threshold:
            return True
    return False
|
||||||
|
|
||||||
|
|
||||||
|
def find_float_index(target, float_list, threshold=1e-6):
    """Return the index of the first element of `float_list` within `threshold` of `target`, or -1."""
    matches = (idx for idx, candidate in enumerate(float_list) if abs(target - candidate) <= threshold)
    return next(matches, -1)
|
||||||
|
|
||||||
|
|
||||||
|
class LookBackError(Exception):
    """Raised when a Backtrackable is asked to look further back than its history buffer allows."""
|
||||||
|
|
||||||
|
|
||||||
|
class LookAheadError(Exception):
    """Raised when a Backtrackable is asked to look further ahead than its lookahead buffer allows."""
|
||||||
|
|
||||||
|
|
||||||
|
class Backtrackable(Generic[T]):
    """
    Wrap any iterator/iterable so you can step back up to `history` items
    and look ahead up to `lookahead` items.

    This is useful for streaming datasets where you need to access previous and future items
    but can't load the entire dataset into memory.

    Internal invariant: `_cursor == 0` means "at the present" (the current item is
    `_back_buf[-1]`); negative values mean we stepped back, and the current item is
    `_back_buf[_cursor - 1]`.

    Example:
    -------
    ```python
    ds = load_dataset("c4", "en", streaming=True, split="train")
    rev = Backtrackable(ds, history=3, lookahead=2)

    x0 = next(rev)  # forward
    x1 = next(rev)
    x2 = next(rev)

    # Look ahead
    x3_peek = rev.peek_ahead(1)  # next item without moving cursor
    x4_peek = rev.peek_ahead(2)  # two items ahead

    # Look back
    x1_again = rev.peek_back(1)  # previous item without moving cursor
    x0_again = rev.peek_back(2)  # two items back

    # Move backward
    x1_back = rev.prev()  # back one step
    next(rev)  # returns x2, continues forward from where we were
    ```
    """

    __slots__ = ("_source", "_back_buf", "_ahead_buf", "_cursor", "_history", "_lookahead")

    def __init__(self, iterable: Iterable[T], *, history: int = 1, lookahead: int = 0):
        """
        Args:
            iterable: Source to iterate over.
            history: Maximum number of already-consumed items kept for stepping back (>= 1).
            lookahead: Maximum number of items that may be pre-fetched; 0 means unbounded.
        """
        if history < 1:
            raise ValueError("history must be >= 1")
        # BUGFIX: the default `lookahead=0` used to be rejected by a `lookahead <= 0`
        # check, making the documented default unconstructible. Only negative values
        # are invalid; 0 selects the unbounded ahead buffer below.
        if lookahead < 0:
            raise ValueError("lookahead must be >= 0")

        self._source: Iterator[T] = iter(iterable)
        self._back_buf: Deque[T] = deque(maxlen=history)
        # lookahead == 0 -> unbounded deque, only filled on demand.
        self._ahead_buf: Deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque()
        self._cursor: int = 0
        self._history = history
        self._lookahead = lookahead

    def __iter__(self) -> "Backtrackable[T]":
        return self

    def __next__(self) -> T:
        # If we've stepped back, replay from the history buffer first.
        if self._cursor < 0:
            # BUGFIX: the cursor must be read *before* advancing; incrementing first
            # returned the item one position too old (inconsistent with peek_back
            # and with the docstring example).
            item = self._back_buf[self._cursor]
            self._cursor += 1
            return item

        # Prefer pre-fetched items from the ahead buffer, then the raw source.
        item = self._ahead_buf.popleft() if self._ahead_buf else next(self._source)

        # Record the item in history and mark us as "at the present".
        self._back_buf.append(item)
        self._cursor = 0
        return item

    def prev(self) -> T:
        """
        Step one item back in history and return it.
        Raises LookBackError if already at the oldest buffered item.
        """
        if len(self._back_buf) + self._cursor <= 1:
            raise LookBackError("At start of history")

        self._cursor -= 1
        # BUGFIX: with the invariant "current item is _back_buf[_cursor - 1]",
        # returning _back_buf[_cursor] handed back the item we were already on.
        return self._back_buf[self._cursor - 1]

    def peek_back(self, n: int = 1) -> T:
        """
        Look `n` items back (n=1 == previous item) without moving the cursor.
        """
        if n < 0 or n + 1 > len(self._back_buf) + self._cursor:
            raise LookBackError("peek_back distance out of range")

        return self._back_buf[self._cursor - (n + 1)]

    def peek_ahead(self, n: int = 1) -> T:
        """
        Look `n` items ahead (n=1 == next item) without moving the cursor.
        Fills the ahead buffer if necessary.
        """
        if n < 1:
            raise LookAheadError("peek_ahead distance must be 1 or more")
        # BUGFIX: the limit only applies when a finite lookahead was configured;
        # with lookahead == 0 the ahead buffer is unbounded (mirrors can_peek_ahead).
        elif self._lookahead > 0 and n > self._lookahead:
            raise LookAheadError("peek_ahead distance exceeds lookahead limit")

        # Pull from the source until the buffer holds at least n items.
        # NOTE(review): like the original, this peeks relative to the furthest item
        # consumed, not relative to a stepped-back cursor — confirm callers only
        # peek ahead while at the present.
        while len(self._ahead_buf) < n:
            try:
                self._ahead_buf.append(next(self._source))
            except StopIteration as err:
                raise LookAheadError("peek_ahead: not enough items in source") from err

        return self._ahead_buf[n - 1]

    def history(self) -> list[T]:
        """
        Return a copy of the buffered history (most recent last).
        The list length ≤ `history` argument passed at construction.
        """
        if self._cursor == 0:
            return list(self._back_buf)

        # When cursor<0, slice so the order remains chronological
        return list(self._back_buf)[: self._cursor or None]

    def lookahead_buffer(self) -> list[T]:
        """
        Return a copy of the current lookahead buffer.
        """
        return list(self._ahead_buf)

    def can_peek_back(self, steps: int = 1) -> bool:
        """
        Check if peek_back(steps) would succeed.
        """
        # BUGFIX: peek_back(n) requires n + 1 items at or before the cursor; the
        # previous `steps <= len + cursor` test reported True one step too early.
        return steps + 1 <= len(self._back_buf) + self._cursor

    def can_peek_ahead(self, steps: int = 1) -> bool:
        """
        Check if we can peek ahead `steps` items.
        This may involve trying to fill the ahead buffer.
        """
        if self._lookahead > 0 and steps > self._lookahead:
            return False

        # Try to fill ahead buffer to check if we can peek that far
        try:
            while len(self._ahead_buf) < steps:
                if self._lookahead > 0 and len(self._ahead_buf) >= self._lookahead:
                    return False
                self._ahead_buf.append(next(self._source))
            return True
        except StopIteration:
            return False

    def reset_cursor(self) -> None:
        """
        Reset cursor to the most recent position (equivalent to calling next()
        until you're back to the latest item).
        """
        self._cursor = 0

    def clear_ahead_buffer(self) -> None:
        """
        Clear the ahead buffer, discarding any pre-fetched items.
        """
        self._ahead_buf.clear()

    def switch_source_iterable(self, new_source: Iterable[T]) -> None:
        """
        Switch the source of the backtrackable to a new iterable, keeping the history.

        This is useful when iterating over a sequence of datasets. The history from the
        previous source is kept, but the lookahead buffer is cleared. The cursor is reset
        to the present.
        """
        self._source = iter(new_source)
        self.clear_ahead_buffer()
        self.reset_cursor()
|
||||||
|
|
||||||
|
|
||||||
|
def safe_shard(dataset: datasets.IterableDataset, index: int, num_shards: int) -> datasets.Dataset:
    """
    Safe shards the dataset.

    The shard index is clamped so it never exceeds the number of shards the
    underlying dataset actually exposes.
    """
    clamped_index = min(dataset.num_shards, index + 1) - 1
    return dataset.shard(num_shards, index=clamped_index)
|
||||||
|
|||||||
@@ -21,9 +21,11 @@ import tempfile
|
|||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from threading import Lock
|
||||||
from typing import Any, ClassVar
|
from typing import Any, ClassVar
|
||||||
|
|
||||||
import av
|
import av
|
||||||
|
import fsspec
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
import torch
|
import torch
|
||||||
import torchvision
|
import torchvision
|
||||||
@@ -169,15 +171,68 @@ def decode_video_frames_torchvision(
|
|||||||
return closest_frames
|
return closest_frames
|
||||||
|
|
||||||
|
|
||||||
|
class VideoDecoderCache:
    """Thread-safe cache for video decoders to avoid expensive re-initialization."""

    def __init__(self):
        # Maps video path -> (decoder, open file handle); all access is lock-guarded.
        self._cache: dict[str, tuple[Any, Any]] = {}
        self._lock = Lock()

    def get_decoder(self, video_path: str):
        """Get a cached decoder or create a new one."""
        if not importlib.util.find_spec("torchcodec"):
            raise ImportError("torchcodec is required but not available.")
        from torchcodec.decoders import VideoDecoder

        cache_key = str(video_path)

        with self._lock:
            if cache_key not in self._cache:
                # Keep the fsspec file handle alive alongside the decoder; it is
                # closed explicitly in clear().
                file_handle = fsspec.open(cache_key).__enter__()
                decoder = VideoDecoder(file_handle, seek_mode="approximate")
                self._cache[cache_key] = (decoder, file_handle)
            return self._cache[cache_key][0]

    def clear(self):
        """Clear the cache and close file handles."""
        with self._lock:
            for _, file_handle in self._cache.values():
                file_handle.close()
            self._cache.clear()

    def size(self) -> int:
        """Return the number of cached decoders."""
        with self._lock:
            return len(self._cache)
|
||||||
|
|
||||||
|
|
||||||
|
class FrameTimestampError(ValueError):
    """Helper error raised when the retrieved frame timestamps exceed the queried ones."""


# Module-level default cache so repeated decodes of the same video reuse a decoder.
_default_decoder_cache = VideoDecoderCache()
|
||||||
|
|
||||||
|
|
||||||
def decode_video_frames_torchcodec(
|
def decode_video_frames_torchcodec(
|
||||||
video_path: Path | str,
|
video_path: Path | str,
|
||||||
timestamps: list[float],
|
timestamps: list[float],
|
||||||
tolerance_s: float,
|
tolerance_s: float,
|
||||||
device: str = "cpu",
|
|
||||||
log_loaded_timestamps: bool = False,
|
log_loaded_timestamps: bool = False,
|
||||||
|
decoder_cache: VideoDecoderCache | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Loads frames associated with the requested timestamps of a video using torchcodec.
|
"""Loads frames associated with the requested timestamps of a video using torchcodec.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
video_path: Path to the video file.
|
||||||
|
timestamps: List of timestamps to extract frames.
|
||||||
|
tolerance_s: Allowed deviation in seconds for frame retrieval.
|
||||||
|
log_loaded_timestamps: Whether to log loaded timestamps.
|
||||||
|
decoder_cache: Optional decoder cache instance. Uses default if None.
|
||||||
|
|
||||||
Note: Setting device="cuda" outside the main process, e.g. in data loader workers, will lead to CUDA initialization errors.
|
Note: Setting device="cuda" outside the main process, e.g. in data loader workers, will lead to CUDA initialization errors.
|
||||||
|
|
||||||
Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
|
Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
|
||||||
@@ -186,27 +241,24 @@ def decode_video_frames_torchcodec(
|
|||||||
and all subsequent frames until reaching the requested frame. The number of key frames in a video
|
and all subsequent frames until reaching the requested frame. The number of key frames in a video
|
||||||
can be adjusted during encoding to take into account decoding time and video size in bytes.
|
can be adjusted during encoding to take into account decoding time and video size in bytes.
|
||||||
"""
|
"""
|
||||||
|
if decoder_cache is None:
|
||||||
|
decoder_cache = _default_decoder_cache
|
||||||
|
|
||||||
if importlib.util.find_spec("torchcodec"):
|
# Use cached decoder instead of creating new one each time
|
||||||
from torchcodec.decoders import VideoDecoder
|
decoder = decoder_cache.get_decoder(str(video_path))
|
||||||
else:
|
|
||||||
raise ImportError("torchcodec is required but not available.")
|
|
||||||
|
|
||||||
# initialize video decoder
|
|
||||||
decoder = VideoDecoder(video_path, device=device, seek_mode="approximate")
|
|
||||||
loaded_frames = []
|
|
||||||
loaded_ts = []
|
loaded_ts = []
|
||||||
|
loaded_frames = []
|
||||||
|
|
||||||
# get metadata for frame information
|
# get metadata for frame information
|
||||||
metadata = decoder.metadata
|
metadata = decoder.metadata
|
||||||
average_fps = metadata.average_fps
|
average_fps = metadata.average_fps
|
||||||
|
|
||||||
# convert timestamps to frame indices
|
# convert timestamps to frame indices
|
||||||
frame_indices = [round(ts * average_fps) for ts in timestamps]
|
frame_indices = [round(ts * average_fps) for ts in timestamps]
|
||||||
|
|
||||||
# retrieve frames based on indices
|
# retrieve frames based on indices
|
||||||
frames_batch = decoder.get_frames_at(indices=frame_indices)
|
frames_batch = decoder.get_frames_at(indices=frame_indices)
|
||||||
|
|
||||||
for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=False):
|
for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=True):
|
||||||
loaded_frames.append(frame)
|
loaded_frames.append(frame)
|
||||||
loaded_ts.append(pts.item())
|
loaded_ts.append(pts.item())
|
||||||
if log_loaded_timestamps:
|
if log_loaded_timestamps:
|
||||||
@@ -237,10 +289,14 @@ def decode_video_frames_torchcodec(
|
|||||||
if log_loaded_timestamps:
|
if log_loaded_timestamps:
|
||||||
logging.info(f"{closest_ts=}")
|
logging.info(f"{closest_ts=}")
|
||||||
|
|
||||||
# convert to float32 in [0,1] range (channel first)
|
# convert to float32 in [0,1] range
|
||||||
closest_frames = closest_frames.type(torch.float32) / 255
|
closest_frames = (closest_frames / 255.0).type(torch.float32)
|
||||||
|
|
||||||
|
if not len(timestamps) == len(closest_frames):
|
||||||
|
raise FrameTimestampError(
|
||||||
|
f"Retrieved timestamps differ from queried {set(closest_frames) - set(timestamps)}"
|
||||||
|
)
|
||||||
|
|
||||||
assert len(timestamps) == len(closest_frames)
|
|
||||||
return closest_frames
|
return closest_frames
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -179,10 +179,11 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
dataset,
|
dataset,
|
||||||
num_workers=cfg.num_workers,
|
num_workers=cfg.num_workers,
|
||||||
batch_size=cfg.batch_size,
|
batch_size=cfg.batch_size,
|
||||||
shuffle=shuffle,
|
shuffle=shuffle and not cfg.dataset.streaming,
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
pin_memory=device.type == "cuda",
|
pin_memory=device.type == "cuda",
|
||||||
drop_last=False,
|
drop_last=False,
|
||||||
|
prefetch_factor=2,
|
||||||
)
|
)
|
||||||
dl_iter = cycle(dataloader)
|
dl_iter = cycle(dataloader)
|
||||||
|
|
||||||
@@ -208,6 +209,9 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
|
|
||||||
for key in batch:
|
for key in batch:
|
||||||
if isinstance(batch[key], torch.Tensor):
|
if isinstance(batch[key], torch.Tensor):
|
||||||
|
if batch[key].dtype != torch.bool:
|
||||||
|
batch[key] = batch[key].type(torch.float32) if device.type == "mps" else batch[key]
|
||||||
|
|
||||||
batch[key] = batch[key].to(device, non_blocking=device.type == "cuda")
|
batch[key] = batch[key].to(device, non_blocking=device.type == "cuda")
|
||||||
|
|
||||||
train_tracker, output_dict = update_policy(
|
train_tracker, output_dict = update_policy(
|
||||||
|
|||||||
391
tests/datasets/test_streaming.py
Normal file
391
tests/datasets/test_streaming.py
Normal file
@@ -0,0 +1,391 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
|
||||||
|
from lerobot.datasets.utils import safe_shard
|
||||||
|
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||||
|
|
||||||
|
|
||||||
|
def get_frames_expected_order(streaming_ds: StreamingLeRobotDataset) -> list[int]:
    """Replicates the shuffling logic of StreamingLeRobotDataset to get the expected order of indices."""
    rng = np.random.default_rng(streaming_ds.seed)
    buffer_size = streaming_ds.buffer_size
    num_shards = streaming_ds.num_shards

    # Collect the frame indices that each shard would yield, in shard order.
    per_shard_indices = []
    for shard_idx in range(num_shards):
        shard = streaming_ds.hf_dataset.shard(num_shards, index=shard_idx)
        per_shard_indices.append([item["index"] for item in shard])

    iterators_by_shard = {shard_id: iter(indices) for shard_id, indices in enumerate(per_shard_indices)}
    buffer_slot_generator = streaming_ds._iter_random_indices(rng, buffer_size)

    frames_buffer: list[int] = []
    expected_indices: list[int] = []

    # Interleave shards exactly like the dataset does: pick a random live shard,
    # pull one frame, and run it through the shuffle buffer.
    while iterators_by_shard:
        live_shard_ids = list(iterators_by_shard.keys())
        if not live_shard_ids:
            break

        # Call _infinite_generator_over_elements with current available shards (key difference!)
        chosen_shard = next(streaming_ds._infinite_generator_over_elements(rng, live_shard_ids))

        try:
            incoming_frame = next(iterators_by_shard[chosen_shard])
            if len(frames_buffer) == buffer_size:
                slot = next(buffer_slot_generator)
                expected_indices.append(frames_buffer[slot])
                frames_buffer[slot] = incoming_frame
            else:
                frames_buffer.append(incoming_frame)
        except StopIteration:
            del iterators_by_shard[chosen_shard]  # Remove exhausted shard

    # Drain the remaining buffer in shuffled order, as the dataset does at epoch end.
    rng.shuffle(frames_buffer)
    expected_indices.extend(frames_buffer)

    return expected_indices
|
||||||
|
|
||||||
|
|
||||||
|
def test_single_frame_consistency(tmp_path, lerobot_dataset_factory):
    """Every streamed frame must equal the same frame read from the in-memory dataset."""
    ds_num_frames = 400
    ds_num_episodes = 10
    buffer_size = 100

    local_path = tmp_path / "test"
    repo_id = f"{DUMMY_REPO_ID}"

    ds = lerobot_dataset_factory(
        root=local_path,
        repo_id=repo_id,
        total_episodes=ds_num_episodes,
        total_frames=ds_num_frames,
    )

    streaming_ds = iter(StreamingLeRobotDataset(repo_id=repo_id, root=local_path, buffer_size=buffer_size))

    key_checks = []
    for _ in range(ds_num_frames):
        streaming_frame = next(streaming_ds)
        frame_idx = streaming_frame["index"]
        target_frame = ds[frame_idx]

        for key in streaming_frame:
            left = streaming_frame[key]
            right = target_frame[key]

            if isinstance(left, str):
                check = left == right
            elif isinstance(left, torch.Tensor):
                check = torch.allclose(left, right) and left.shape == right.shape
            elif isinstance(left, float):
                check = left == right.item()  # right is a torch.Tensor
            else:
                # BUGFIX: values of any other type previously reused the stale `check`
                # from the preceding key (or raised NameError on the very first key).
                check = left == right

            key_checks.append((key, check))

        assert all(t[1] for t in key_checks), (
            f"Checking {list(filter(lambda t: not t[1], key_checks))[0][0]} left and right were found different (frame_idx: {frame_idx})"
        )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "shuffle",
    [False, True],
)
def test_frames_order_over_epochs(tmp_path, lerobot_dataset_factory, shuffle):
    """Test if streamed frames correspond to shuffling operations over in-memory dataset."""
    ds_num_frames = 400
    ds_num_episodes = 10
    buffer_size = 100
    seed = 42
    n_epochs = 3

    local_path = tmp_path / "test"
    repo_id = f"{DUMMY_REPO_ID}"

    lerobot_dataset_factory(
        root=local_path,
        repo_id=repo_id,
        total_episodes=ds_num_episodes,
        total_frames=ds_num_frames,
    )

    streaming_ds = StreamingLeRobotDataset(
        repo_id=repo_id, root=local_path, buffer_size=buffer_size, seed=seed, shuffle=shuffle
    )

    # The first epoch must reproduce the reference shuffle-buffer simulation exactly.
    first_epoch_indices = [frame["index"] for frame in streaming_ds]
    expected_indices = get_frames_expected_order(streaming_ds)
    assert first_epoch_indices == expected_indices, "First epoch indices do not match expected indices"

    # Re-run the simulation, then compare subsequent epochs against it: with
    # shuffling enabled, later epochs must diverge; without it they must match.
    expected_indices = get_frames_expected_order(streaming_ds)
    for _ in range(n_epochs):
        epoch_indices = [frame["index"] for frame in streaming_ds]
        frames_match = all(
            observed == expected for observed, expected in zip(epoch_indices, expected_indices, strict=True)
        )
        if shuffle:
            assert not frames_match
        else:
            assert frames_match
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "shuffle",
    [False, True],
)
def test_frames_order_with_shards(tmp_path, lerobot_dataset_factory, shuffle):
    """Test if streamed frames correspond to shuffling operations over in-memory dataset with multiple shards."""
    ds_num_frames = 100
    ds_num_episodes = 10
    buffer_size = 10
    seed = 42
    n_epochs = 3
    # Tiny data files + chunk size force the dataset to split into several shards.
    data_file_size_mb = 0.001
    chunks_size = 1

    local_path = tmp_path / "test"
    repo_id = f"{DUMMY_REPO_ID}-ciao"

    lerobot_dataset_factory(
        root=local_path,
        repo_id=repo_id,
        total_episodes=ds_num_episodes,
        total_frames=ds_num_frames,
        data_files_size_in_mb=data_file_size_mb,
        chunks_size=chunks_size,
    )

    streaming_ds = StreamingLeRobotDataset(
        repo_id=repo_id,
        root=local_path,
        buffer_size=buffer_size,
        seed=seed,
        shuffle=shuffle,
        max_num_shards=4,
    )

    first_epoch_indices = [frame["index"] for frame in streaming_ds]
    expected_indices = get_frames_expected_order(streaming_ds)
    assert first_epoch_indices == expected_indices, "First epoch indices do not match expected indices"

    for _ in range(n_epochs):
        # NOTE: this is the same as first_epoch_indices
        epoch_indices = [frame["index"] for frame in streaming_ds]
        frames_match = all(
            observed == expected for observed, expected in zip(epoch_indices, expected_indices, strict=True)
        )
        if shuffle:
            assert not frames_match
        else:
            assert frames_match
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "state_deltas, action_deltas",
    [
        ([-1, -0.5, -0.20, 0], [0, 1, 2, 3]),
        ([-1, -0.5, -0.20, 0], [-1.5, -1, -0.5, -0.20, -0.10, 0]),
        ([-2, -1, -0.5, 0], [0, 1, 2, 3]),
        ([-2, -1, -0.5, 0], [-1.5, -1, -0.5, -0.20, -0.10, 0]),
    ],
)
def test_frames_with_delta_consistency(tmp_path, lerobot_dataset_factory, state_deltas, action_deltas):
    """Streamed frames with delta_timestamps must match the in-memory dataset on non-padded regions."""
    ds_num_frames = 500
    ds_num_episodes = 10
    buffer_size = 100
    seed = 42

    local_path = tmp_path / "test"
    repo_id = f"{DUMMY_REPO_ID}-ciao"
    camera_key = "phone"

    delta_timestamps = {
        camera_key: state_deltas,
        "state": state_deltas,
        "action": action_deltas,
    }

    ds = lerobot_dataset_factory(
        root=local_path,
        repo_id=repo_id,
        total_episodes=ds_num_episodes,
        total_frames=ds_num_frames,
        delta_timestamps=delta_timestamps,
    )

    streaming_ds = iter(
        StreamingLeRobotDataset(
            repo_id=repo_id,
            root=local_path,
            buffer_size=buffer_size,
            seed=seed,
            shuffle=False,
            delta_timestamps=delta_timestamps,
        )
    )

    for i in range(ds_num_frames):
        streaming_frame = next(streaming_ds)
        frame_idx = streaming_frame["index"]
        target_frame = ds[frame_idx]

        assert set(streaming_frame.keys()) == set(target_frame.keys()), (
            f"Keys differ between streaming frame and target one. Differ at: {set(streaming_frame.keys()) - set(target_frame.keys())}"
        )

        key_checks = []
        for key in streaming_frame:
            left = streaming_frame[key]
            right = target_frame[key]

            if isinstance(left, str):
                check = left == right
            elif isinstance(left, torch.Tensor):
                if (
                    key not in ds.meta.camera_keys
                    and "is_pad" not in key
                    and f"{key}_is_pad" in streaming_frame
                ):
                    # comparing frames only on non-padded regions. Padding is applied to last-valid broadcasting
                    left = left[~streaming_frame[f"{key}_is_pad"]]
                    right = right[~target_frame[f"{key}_is_pad"]]

                check = torch.allclose(left, right) and left.shape == right.shape
            else:
                # BUGFIX: non-str/tensor values previously reused the stale `check`
                # from the preceding key (or raised NameError on the very first key).
                check = left == right

            key_checks.append((key, check))

        assert all(t[1] for t in key_checks), (
            f"Checking {list(filter(lambda t: not t[1], key_checks))[0][0]} left and right were found different (i: {i}, frame_idx: {frame_idx})"
        )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "state_deltas, action_deltas",
    [
        ([-1, -0.5, -0.20, 0], [0, 1, 2, 3, 10, 20]),
        ([-1, -0.5, -0.20, 0], [-20, -1.5, -1, -0.5, -0.20, -0.10, 0]),
        ([-2, -1, -0.5, 0], [0, 1, 2, 3, 10, 20]),
        ([-2, -1, -0.5, 0], [-20, -1.5, -1, -0.5, -0.20, -0.10, 0]),
    ],
)
def test_frames_with_delta_consistency_with_shards(
    tmp_path, lerobot_dataset_factory, state_deltas, action_deltas
):
    """Check that a sharded streaming dataset with ``delta_timestamps`` yields frames
    identical to random-access reads of the same (non-streaming) dataset.

    A tiny ``data_files_size_in_mb`` forces the dataset into many small files so that
    sharding is actually exercised. With ``shuffle=False`` the streaming iterator is
    expected to visit frames whose ``index`` can be looked up directly in the map-style
    dataset; each common key is then compared value-by-value.

    Args:
        tmp_path: pytest-provided temporary directory for the dataset root.
        lerobot_dataset_factory: project fixture that builds a local LeRobotDataset.
        state_deltas: delta timestamps applied to the camera and "state" features.
        action_deltas: delta timestamps applied to the "action" feature.
    """
    ds_num_frames = 100
    ds_num_episodes = 10
    buffer_size = 10
    data_file_size_mb = 0.001  # tiny files -> many data files, so sharding kicks in
    chunks_size = 1
    # Single source of truth for the shard count (previously the literal 4 was
    # duplicated between max_num_shards and the local sharding loop below).
    num_shards = 4

    seed = 42

    local_path = tmp_path / "test"
    repo_id = f"{DUMMY_REPO_ID}-ciao"
    camera_key = "phone"

    delta_timestamps = {
        camera_key: state_deltas,
        "state": state_deltas,
        "action": action_deltas,
    }

    # Ground-truth, random-access dataset.
    ds = lerobot_dataset_factory(
        root=local_path,
        repo_id=repo_id,
        total_episodes=ds_num_episodes,
        total_frames=ds_num_frames,
        delta_timestamps=delta_timestamps,
        data_files_size_in_mb=data_file_size_mb,
        chunks_size=chunks_size,
    )
    streaming_ds = StreamingLeRobotDataset(
        repo_id=repo_id,
        root=local_path,
        buffer_size=buffer_size,
        seed=seed,
        shuffle=False,
        delta_timestamps=delta_timestamps,
        max_num_shards=num_shards,
    )

    # NOTE(review): called for its side effect — presumably this materializes
    # streaming_ds.hf_dataset before it is sharded below; confirm it is required.
    iter(streaming_ds)

    shards_indices = []
    for shard_idx in range(num_shards):
        shard = safe_shard(streaming_ds.hf_dataset, shard_idx, num_shards)
        shard_indices = [item["index"] for item in shard]
        shards_indices.append(shard_indices)
    # TODO(review): shards_indices is collected but never asserted on — either add a
    # coverage/partition assertion (e.g. the shards form a disjoint cover of all
    # indices) or drop this computation.

    streaming_ds = iter(streaming_ds)

    for i in range(ds_num_frames):
        streaming_frame = next(streaming_ds)
        frame_idx = streaming_frame["index"]
        target_frame = ds[frame_idx]

        assert set(streaming_frame.keys()) == set(target_frame.keys()), (
            f"Keys differ between streaming frame and target one. Differ at: {set(streaming_frame.keys()) - set(target_frame.keys())}"
        )

        key_checks = []
        for key in streaming_frame:
            left = streaming_frame[key]
            right = target_frame[key]

            if isinstance(left, str):
                check = left == right
            elif isinstance(left, torch.Tensor):
                if (
                    key not in ds.meta.camera_keys
                    and "is_pad" not in key
                    and f"{key}_is_pad" in streaming_frame
                ):
                    # comparing frames only on non-padded regions. Padding is applied
                    # to last-valid broadcasting
                    left = left[~streaming_frame[f"{key}_is_pad"]]
                    right = right[~target_frame[f"{key}_is_pad"]]

                # Compare shapes BEFORE values: torch.allclose raises a RuntimeError
                # on non-broadcastable shapes, which would crash the test instead of
                # producing a clean assertion message.
                check = left.shape == right.shape and torch.allclose(left, right)
            elif isinstance(left, float):
                check = left == right.item()  # right is a torch.Tensor
            else:
                # Previously `check` was left unbound here, turning an unexpected
                # value type into a NameError; record it as a plain failure instead.
                check = False

            key_checks.append((key, check))

        assert all(t[1] for t in key_checks), (
            f"Checking {list(filter(lambda t: not t[1], key_checks))[0][0]} left and right were found different (i: {i}, frame_idx: {frame_idx})"
        )
|
||||||
Reference in New Issue
Block a user