Update LeRobotDataset.__get_item__

This commit is contained in:
Simon Alibert
2024-10-10 21:32:14 +02:00
parent 3113038beb
commit b417cebc4e
3 changed files with 232 additions and 128 deletions

View File

@@ -15,25 +15,27 @@
# limitations under the License.
import logging
import os
from itertools import accumulate
from pathlib import Path
from typing import Callable
import datasets
import torch
import torch.utils
from huggingface_hub import snapshot_download
from lerobot.common.datasets.compute_stats import aggregate_stats
from lerobot.common.datasets.utils import (
download_episodes,
check_delta_timestamps,
check_timestamps_sync,
get_delta_indices,
get_episode_data_index,
get_hub_safe_version,
load_hf_dataset,
load_info,
load_previous_and_future_frames,
load_stats,
load_tasks,
)
from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos
from lerobot.common.datasets.video_utils import VideoFrame, decode_video_frames_torchvision
# For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
CODEBASE_VERSION = "v2.0"
@@ -49,6 +51,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
split: str = "train",
image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None,
tolerance_s: float = 1e-4,
video_backend: str | None = None,
):
super().__init__()
@@ -58,7 +61,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
self.episodes = episodes
self.tolerance_s = tolerance_s
self.video_backend = video_backend if video_backend is not None else "pyav"
self.delta_indices = None
# Load metadata
self.root.mkdir(exist_ok=True, parents=True)
@@ -68,34 +73,60 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.tasks = load_tasks(repo_id, self._version, self.root)
# Load actual data
download_episodes(
repo_id,
self._version,
self.root,
self.data_path,
self.video_keys,
self.num_episodes,
self.episodes,
self.videos_path,
)
self.download_episodes()
self.hf_dataset = load_hf_dataset(self.root, self.data_path, self.total_episodes, self.episodes)
self.episode_data_index = self.get_episode_data_index()
self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
# Check timestamps
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
# Setup delta_indices
if self.delta_timestamps is not None:
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
# TODO(aliberts):
# - [ ] Update __get_item__
# - [X] Move delta_timestamp logic outside __get_item__
# - [X] Update __get_item__
# - [ ] Add self.add_frame()
# - [ ] Add self.consolidate() for:
# - [X] Check timestamps sync
# - [ ] Sanity checks (episodes num, shapes, files, etc.)
# - [ ] Update episode_index (arg update=True)
# - [ ] Update info.json (arg update=True)
# TODO(aliberts): remove (deprecated)
# if split == "train":
# self.episode_data_index = load_episode_data_index(self.episodes, self.episode_list)
# else:
# self.episode_data_index = calculate_episode_data_index(self.hf_dataset)
# self.hf_dataset = reset_episode_index(self.hf_dataset)
# if self.video:
# self.videos_dir = load_videos(repo_id, CODEBASE_VERSION, root)
def download_episodes(self) -> None:
"""Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this
will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present
in 'local_dir', they won't be downloaded again.
Note: Currently, if you're running this code offline but you already have the files in 'local_dir',
snapshot_download will still fail. This behavior will be fixed in an upcoming update of huggingface_hub.
"""
# TODO(rcadene, aliberts): implement faster transfer
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
files = None
if self.episodes is not None:
files = [
self.data_path.format(episode_index=ep_idx, total_episodes=self.total_episodes)
for ep_idx in self.episodes
]
if len(self.video_keys) > 0:
video_files = [
self.videos_path.format(video_key=vid_key, episode_index=ep_idx)
for vid_key in self.video_keys
for ep_idx in self.episodes
]
files += video_files
snapshot_download(
self.repo_id,
repo_type="dataset",
revision=self._version,
local_dir=self.root,
allow_patterns=files,
)
@property
def data_path(self) -> str:
@@ -134,17 +165,20 @@ class LeRobotDataset(torch.utils.data.Dataset):
@property
def camera_keys(self) -> list[str]:
"""Keys to access image and video streams from cameras."""
"""Keys to access image and video streams from cameras (regardless of their storage method)."""
return self.image_keys + self.video_keys
@property
def video_frame_keys(self) -> list[str]:
"""Keys to access video frames that requires to be decoded into images.
"""
DEPRECATED, USE 'video_keys' INSTEAD
Keys to access video frames that requires to be decoded into images.
Note: It is empty if the dataset contains images only,
or equal to `self.cameras` if the dataset contains videos only,
or can even be a subset of `self.cameras` in a case of a mixed image/video dataset.
"""
# TODO(aliberts): remove
video_frame_keys = []
for key, feats in self.hf_dataset.features.items():
if isinstance(feats, VideoFrame):
@@ -166,54 +200,97 @@ class LeRobotDataset(torch.utils.data.Dataset):
"""Total number of episodes available."""
return self.info["total_episodes"]
@property
def tolerance_s(self) -> float:
"""Tolerance in seconds used to discard loaded frames when their timestamps
are not close enough from the requested frames. It is only used when `delta_timestamps`
is provided or when loading video frames from mp4 files.
"""
# 1e-4 to account for possible numerical error
return 1 / self.fps - 1e-4
# @property
# def tolerance_s(self) -> float:
# """Tolerance in seconds used to discard loaded frames when their timestamps
# are not close enough from the requested frames. It is used at the init of the dataset to make sure
# that each timestamps is separated to the next by 1/fps +/- tolerance. It is only used when
# `delta_timestamps` is provided or when loading video frames from mp4 files.
# """
# # 1e-4 to account for possible numerical error
# return 1e-4
@property
def shapes(self) -> dict:
"""Shapes for the different features."""
self.info.get("shapes")
def get_episode_data_index(self) -> dict[str, torch.Tensor]:
episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(self.episode_dicts)}
def current_episode_index(self, idx: int) -> int:
episode_index = self.hf_dataset["episode_index"][idx]
if self.episodes is not None:
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in self.episodes}
# get episode_index from selected episodes
episode_index = self.episodes.index(episode_index)
cumulative_lenghts = list(accumulate(episode_lengths.values()))
return episode_index
def episode_length(self, episode_index) -> int:
"""Number of samples/frames for given episode."""
return self.info["episodes"][episode_index]["length"]
def _get_query_indices(self, idx: int, ep_idx: int) -> dict[str, list[int]]:
# Pad values outside of current episode range
ep_start = self.episode_data_index["from"][ep_idx]
ep_end = self.episode_data_index["to"][ep_idx]
return {
"from": torch.LongTensor([0] + cumulative_lenghts[:-1]),
"to": torch.LongTensor(cumulative_lenghts),
key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx]
for key, delta_idx in self.delta_indices.items()
}
def _get_query_timestamps(
self, query_indices: dict[str, list[int]], current_ts: float
) -> dict[str, list[float]]:
query_timestamps = {}
for key in self.video_keys:
if key in query_indices:
timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
query_timestamps[key] = torch.stack(timestamps).tolist()
else:
query_timestamps[key] = [current_ts]
return query_timestamps
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
return {
key: torch.stack(self.hf_dataset.select(q_idx)[key])
for key, q_idx in query_indices.items()
if key not in self.video_keys
}
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict:
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a
Segmentation Fault. This probably happens because a memory reference to the video loader is created in
the main process and a subprocess fails to access it.
"""
item = {}
for vid_key, query_ts in query_timestamps.items():
video_path = self.root / self.videos_path.format(video_key=vid_key, episode_index=ep_idx)
frames = decode_video_frames_torchvision(
video_path, query_ts, self.tolerance_s, self.video_backend
)
item[vid_key] = frames
return item
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
def __getitem__(self, idx) -> dict:
item = self.hf_dataset[idx]
ep_idx = item["episode_index"].item()
if self.delta_timestamps is not None:
item = load_previous_and_future_frames(
item,
self.hf_dataset,
self.episode_data_index,
self.delta_timestamps,
self.tolerance_s,
)
if self.delta_indices is not None:
current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx
query_indices = self._get_query_indices(idx, current_ep_idx)
query_result = self._query_hf_dataset(query_indices)
for key, val in query_result.items():
item[key] = val
if self.video:
item = load_from_videos(
item,
self.video_keys,
self.videos_dir,
self.tolerance_s,
self.video_backend,
)
if len(self.video_keys) > 0:
current_ts = item["timestamp"].item()
query_timestamps = self._get_query_timestamps(query_indices, current_ts)
video_frames = self._query_videos(query_timestamps, ep_idx)
item = {**video_frames, **item}
if self.image_transforms is not None:
for cam in self.camera_keys: