Update LeRobotDataset.__get_item__
This commit is contained in:
@@ -16,13 +16,15 @@
|
||||
import json
|
||||
import warnings
|
||||
from functools import cache
|
||||
from itertools import accumulate
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from typing import Dict
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import DatasetCard, HfApi, hf_hub_download, snapshot_download
|
||||
from huggingface_hub import DatasetCard, HfApi, hf_hub_download
|
||||
from PIL import Image as PILImage
|
||||
from torchvision import transforms
|
||||
|
||||
@@ -193,40 +195,102 @@ def load_tasks(repo_id: str, version: str, local_dir: Path) -> dict:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def download_episodes(
|
||||
repo_id: str,
|
||||
version: str,
|
||||
local_dir: Path,
|
||||
data_path: str,
|
||||
video_keys: list,
|
||||
total_episodes: int,
|
||||
episodes: list[int] | None = None,
|
||||
videos_path: str | None = None,
|
||||
) -> None:
|
||||
"""Downloads the dataset from the given 'repo_id' at the provided 'version'. If 'episodes' is given, this
|
||||
will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
|
||||
dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present
|
||||
in 'local_dir', they won't be downloaded again.
|
||||
|
||||
Note: Currently, if you're running this code offline but you already have the files in 'local_dir',
|
||||
snapshot_download will still fail. This behavior will be fixed in an upcoming update of huggingface_hub.
|
||||
"""
|
||||
# TODO(rcadene, aliberts): implement faster transfer
|
||||
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
|
||||
files = None
|
||||
def get_episode_data_index(episodes: list, episode_dicts: list[dict]) -> dict[str, torch.Tensor]:
|
||||
episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(episode_dicts)}
|
||||
if episodes is not None:
|
||||
files = [data_path.format(episode_index=ep_idx, total_episodes=total_episodes) for ep_idx in episodes]
|
||||
if len(video_keys) > 0:
|
||||
video_files = [
|
||||
videos_path.format(video_key=vid_key, episode_index=ep_idx)
|
||||
for vid_key in video_keys
|
||||
for ep_idx in episodes
|
||||
]
|
||||
files += video_files
|
||||
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}
|
||||
|
||||
snapshot_download(
|
||||
repo_id, repo_type="dataset", revision=version, local_dir=local_dir, allow_patterns=files
|
||||
)
|
||||
cumulative_lenghts = list(accumulate(episode_lengths.values()))
|
||||
return {
|
||||
"from": torch.LongTensor([0] + cumulative_lenghts[:-1]),
|
||||
"to": torch.LongTensor(cumulative_lenghts),
|
||||
}
|
||||
|
||||
|
||||
def check_timestamps_sync(
|
||||
hf_dataset: datasets.Dataset,
|
||||
episode_data_index: dict[str, torch.Tensor],
|
||||
fps: int,
|
||||
tolerance_s: float,
|
||||
raise_value_error: bool = True,
|
||||
) -> bool:
|
||||
"""
|
||||
This check is to make sure that each timestamps is separated to the next by 1/fps +/- tolerance to
|
||||
account for possible numerical error.
|
||||
"""
|
||||
timestamps = torch.stack(hf_dataset["timestamp"])
|
||||
# timestamps[2] += tolerance_s # TODO delete
|
||||
# timestamps[-2] += tolerance_s/2 # TODO delete
|
||||
diffs = torch.diff(timestamps)
|
||||
within_tolerance = torch.abs(diffs - 1 / fps) <= tolerance_s
|
||||
|
||||
# We mask differences between the timestamp at the end of an episode
|
||||
# and the one the start of the next episode since these are expected
|
||||
# to be outside tolerance.
|
||||
mask = torch.ones(len(diffs), dtype=torch.bool)
|
||||
ignored_diffs = episode_data_index["to"][:-1] - 1
|
||||
mask[ignored_diffs] = False
|
||||
filtered_within_tolerance = within_tolerance[mask]
|
||||
|
||||
if not torch.all(filtered_within_tolerance):
|
||||
# Track original indices before masking
|
||||
original_indices = torch.arange(len(diffs))
|
||||
filtered_indices = original_indices[mask]
|
||||
outside_tolerance_filtered_indices = torch.nonzero(~filtered_within_tolerance).squeeze()
|
||||
outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices]
|
||||
episode_indices = torch.stack(hf_dataset["episode_index"])
|
||||
|
||||
outside_tolerances = []
|
||||
for idx in outside_tolerance_indices:
|
||||
entry = {
|
||||
"timestamps": [timestamps[idx], timestamps[idx + 1]],
|
||||
"diff": diffs[idx],
|
||||
"episode_index": episode_indices[idx].item(),
|
||||
}
|
||||
outside_tolerances.append(entry)
|
||||
|
||||
if raise_value_error:
|
||||
raise ValueError(
|
||||
f"""One or several timestamps unexpectedly violate the tolerance inside episode range.
|
||||
This might be due to synchronization issues with timestamps during data collection.
|
||||
\n{pformat(outside_tolerances)}"""
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def check_delta_timestamps(
|
||||
delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
|
||||
) -> bool:
|
||||
outside_tolerance = {}
|
||||
for key, delta_ts in delta_timestamps.items():
|
||||
abs_delta_ts = torch.abs(torch.tensor(delta_ts))
|
||||
within_tolerance = (abs_delta_ts % (1 / fps)) <= tolerance_s
|
||||
if not torch.all(within_tolerance):
|
||||
outside_tolerance[key] = torch.tensor(delta_ts)[~within_tolerance]
|
||||
|
||||
if len(outside_tolerance) > 0:
|
||||
if raise_value_error:
|
||||
raise ValueError(
|
||||
f"""
|
||||
The following delta_timestamps are found outside of tolerance range.
|
||||
Please make sure they are multiples of 1/{fps} +/- tolerance and adjust
|
||||
their values accordingly.
|
||||
\n{pformat(outside_tolerance)}
|
||||
"""
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
|
||||
delta_indices = {}
|
||||
for key, delta_ts in delta_timestamps.items():
|
||||
delta_indices[key] = (torch.tensor(delta_ts) * fps).long().tolist()
|
||||
|
||||
return delta_indices
|
||||
|
||||
|
||||
def load_previous_and_future_frames(
|
||||
|
||||
Reference in New Issue
Block a user