Update LeRobotDataset.__getitem__

This commit is contained in:
Simon Alibert
2024-10-10 21:32:14 +02:00
parent 3113038beb
commit b417cebc4e
3 changed files with 232 additions and 128 deletions

View File

@@ -16,13 +16,15 @@
import json
import warnings
from functools import cache
from itertools import accumulate
from pathlib import Path
from pprint import pformat
from typing import Dict
import datasets
import torch
from datasets import load_dataset
from huggingface_hub import DatasetCard, HfApi, hf_hub_download, snapshot_download
from huggingface_hub import DatasetCard, HfApi, hf_hub_download
from PIL import Image as PILImage
from torchvision import transforms
@@ -193,40 +195,102 @@ def load_tasks(repo_id: str, version: str, local_dir: Path) -> dict:
return json.load(f)
def download_episodes(
repo_id: str,
version: str,
local_dir: Path,
data_path: str,
video_keys: list,
total_episodes: int,
episodes: list[int] | None = None,
videos_path: str | None = None,
) -> None:
"""Downloads the dataset from the given 'repo_id' at the provided 'version'. If 'episodes' is given, this
will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present
in 'local_dir', they won't be downloaded again.
Note: Currently, if you're running this code offline but you already have the files in 'local_dir',
snapshot_download will still fail. This behavior will be fixed in an upcoming update of huggingface_hub.
"""
# TODO(rcadene, aliberts): implement faster transfer
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
files = None
def get_episode_data_index(episodes: list, episode_dicts: list[dict]) -> dict[str, torch.Tensor]:
episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(episode_dicts)}
if episodes is not None:
files = [data_path.format(episode_index=ep_idx, total_episodes=total_episodes) for ep_idx in episodes]
if len(video_keys) > 0:
video_files = [
videos_path.format(video_key=vid_key, episode_index=ep_idx)
for vid_key in video_keys
for ep_idx in episodes
]
files += video_files
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}
snapshot_download(
repo_id, repo_type="dataset", revision=version, local_dir=local_dir, allow_patterns=files
)
cumulative_lenghts = list(accumulate(episode_lengths.values()))
return {
"from": torch.LongTensor([0] + cumulative_lenghts[:-1]),
"to": torch.LongTensor(cumulative_lenghts),
}
def check_timestamps_sync(
    hf_dataset: datasets.Dataset,
    episode_data_index: dict[str, torch.Tensor],
    fps: int,
    tolerance_s: float,
    raise_value_error: bool = True,
) -> bool:
    """Check that consecutive timestamps inside each episode are 1/fps apart, +/- 'tolerance_s'.

    The tolerance accounts for possible numerical error.

    Args:
        hf_dataset: Dataset whose "timestamp" and "episode_index" columns can be stacked with
            torch.stack (i.e. each column yields a sequence of scalar tensors).
        episode_data_index: Dict whose "to" tensor holds, for each episode, the index one past
            its last frame.
        fps: Expected number of frames per second.
        tolerance_s: Maximum accepted deviation, in seconds, from the expected 1/fps spacing.
        raise_value_error: If True, raise instead of returning False when a violation is found.

    Returns:
        True if every in-episode timestamp difference is within tolerance, False otherwise
        (only when 'raise_value_error' is False).

    Raises:
        ValueError: If a violation is found and 'raise_value_error' is True. The message lists
            the offending timestamp pairs, their difference and their episode index.
    """
    timestamps = torch.stack(hf_dataset["timestamp"])
    diffs = torch.diff(timestamps)
    within_tolerance = torch.abs(diffs - 1 / fps) <= tolerance_s

    # We mask differences between the timestamp at the end of an episode
    # and the one at the start of the next episode since these are expected
    # to be outside tolerance.
    mask = torch.ones(len(diffs), dtype=torch.bool)
    ignored_diffs = episode_data_index["to"][:-1] - 1
    mask[ignored_diffs] = False

    violations = mask & ~within_tolerance
    if not torch.any(violations):
        return True
    if not raise_value_error:
        return False

    # flatten() always yields a 1-D tensor, so the report loop also works when exactly one
    # difference is out of tolerance (a bare squeeze() would produce a non-iterable 0-dim tensor).
    outside_tolerance_indices = torch.nonzero(violations).flatten()
    episode_indices = torch.stack(hf_dataset["episode_index"])
    outside_tolerances = [
        {
            "timestamps": [timestamps[idx], timestamps[idx + 1]],
            "diff": diffs[idx],
            "episode_index": episode_indices[idx].item(),
        }
        for idx in outside_tolerance_indices
    ]
    raise ValueError(
        f"""One or several timestamps unexpectedly violate the tolerance inside episode range.
        This might be due to synchronization issues with timestamps during data collection.
        \n{pformat(outside_tolerances)}"""
    )
def check_delta_timestamps(
    delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
) -> bool:
    """Check that every delta timestamp is a multiple of 1/fps, +/- 'tolerance_s'.

    Args:
        delta_timestamps: Maps a key to its list of time offsets, in seconds.
        fps: Number of frames per second; valid offsets are multiples of 1/fps.
        tolerance_s: Maximum accepted distance, in seconds, between an offset and the nearest
            multiple of 1/fps.
        raise_value_error: If True, raise instead of returning False when a violation is found.

    Returns:
        True if all offsets are within tolerance, False otherwise (only when
        'raise_value_error' is False).

    Raises:
        ValueError: If some offsets are outside tolerance and 'raise_value_error' is True.
    """
    outside_tolerance = {}
    for key, delta_ts in delta_timestamps.items():
        # Measure the distance to the *nearest* frame, in float64. A plain modulo only accepts
        # values slightly above a multiple: a value slightly below one (or a float32 rounding
        # artefact) leaves a remainder close to 1/fps and would be wrongly rejected.
        frame_offsets = torch.tensor(delta_ts, dtype=torch.float64) * fps
        distance_s = torch.abs(frame_offsets - torch.round(frame_offsets)) / fps
        within_tolerance = distance_s <= tolerance_s
        if not torch.all(within_tolerance):
            outside_tolerance[key] = torch.tensor(delta_ts)[~within_tolerance]

    if len(outside_tolerance) > 0:
        if raise_value_error:
            raise ValueError(
                f"""
                The following delta_timestamps are found outside of tolerance range.
                Please make sure they are multiples of 1/{fps} +/- tolerance and adjust
                their values accordingly.
                \n{pformat(outside_tolerance)}
                """
            )
        return False

    return True
def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
    """Convert time offsets in seconds into frame offsets.

    Args:
        delta_timestamps: Maps a key to its list of time offsets, in seconds.
        fps: Number of frames per second.

    Returns:
        A dict with the same keys, where each offset is replaced by the nearest integer number
        of frames.
    """
    # Round to the nearest frame rather than truncating toward zero: an offset that is a
    # multiple of 1/fps only up to a small tolerance (e.g. 0.0333 at 30 fps) must still map
    # to its intended frame.
    return {
        key: [int(round(delta * fps)) for delta in delta_ts]
        for key, delta_ts in delta_timestamps.items()
    }
def load_previous_and_future_frames(