Rework LeRobotDataset.__init__

This commit is contained in:
Simon Alibert
2024-10-09 14:33:26 +02:00
parent 2d75b93ba0
commit 096824b5ff
2 changed files with 189 additions and 114 deletions

View File

@@ -15,6 +15,7 @@
# limitations under the License.
import logging
import os
from itertools import accumulate
from pathlib import Path
from typing import Callable
@@ -24,27 +25,27 @@ import torch.utils
from lerobot.common.datasets.compute_stats import aggregate_stats
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
load_episode_data_index,
download_episodes,
get_hub_safe_version,
load_hf_dataset,
load_info,
load_previous_and_future_frames,
load_stats,
load_videos,
reset_episode_index,
load_tasks,
)
from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos
# For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
CODEBASE_VERSION = "v1.6"
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
CODEBASE_VERSION = "v2.0"
LEROBOT_HOME = Path(os.getenv("LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser()
class LeRobotDataset(torch.utils.data.Dataset):
def __init__(
self,
repo_id: str,
root: Path | None = DATA_DIR,
root: Path | None = None,
episodes: list[int] | None = None,
split: str = "train",
image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None,
@@ -52,24 +53,64 @@ class LeRobotDataset(torch.utils.data.Dataset):
):
super().__init__()
self.repo_id = repo_id
self.root = root
self.root = root if root is not None else LEROBOT_HOME / repo_id
self.split = split
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
# load data from hub or locally when root is provided
# TODO(rcadene, aliberts): implement faster transfer
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
self.hf_dataset = load_hf_dataset(repo_id, CODEBASE_VERSION, root, split)
if split == "train":
self.episode_data_index = load_episode_data_index(repo_id, CODEBASE_VERSION, root)
else:
self.episode_data_index = calculate_episode_data_index(self.hf_dataset)
self.hf_dataset = reset_episode_index(self.hf_dataset)
self.stats = load_stats(repo_id, CODEBASE_VERSION, root)
self.info = load_info(repo_id, CODEBASE_VERSION, root)
if self.video:
self.videos_dir = load_videos(repo_id, CODEBASE_VERSION, root)
self.video_backend = video_backend if video_backend is not None else "pyav"
self.episodes = episodes
self.video_backend = video_backend if video_backend is not None else "pyav"
# Load metadata
self.root.mkdir(exist_ok=True, parents=True)
self._version = get_hub_safe_version(repo_id, CODEBASE_VERSION)
self.info = load_info(repo_id, self._version, self.root)
self.stats = load_stats(repo_id, self._version, self.root)
self.tasks = load_tasks(repo_id, self._version, self.root)
# Load actual data
download_episodes(
repo_id,
self._version,
self.root,
self.data_path,
self.video_keys,
self.num_episodes,
self.episodes,
self.videos_path,
)
self.hf_dataset = load_hf_dataset(self.root, self.data_path, self.total_episodes, self.episodes)
self.episode_data_index = self.get_episode_data_index()
# TODO(aliberts):
# - [ ] Update __get_item__
# - [ ] Add self.consolidate() for:
# - [ ] Sanity checks (episodes num, shapes, files, etc.)
# - [ ] Update episode_index (arg update=True)
# - [ ] Update info.json (arg update=True)
# TODO(aliberts): remove (deprecated)
# if split == "train":
# self.episode_data_index = load_episode_data_index(self.episodes, self.episode_list)
# else:
# self.episode_data_index = calculate_episode_data_index(self.hf_dataset)
# self.hf_dataset = reset_episode_index(self.hf_dataset)
# if self.video:
# self.videos_dir = load_videos(repo_id, CODEBASE_VERSION, root)
@property
def data_path(self) -> str:
"""Formattable string for the parquet files."""
return self.info["data_path"]
@property
def videos_path(self) -> str | None:
"""Formattable string for the video files."""
return self.info["videos"]["videos_path"] if len(self.video_keys) > 0 else None
@property
def episode_dicts(self) -> list[dict]:
"""List of dictionary containing information for each episode, indexed by episode_index."""
return self.info["episodes"]
@property
def fps(self) -> int:
@@ -77,24 +118,24 @@ class LeRobotDataset(torch.utils.data.Dataset):
return self.info["fps"]
@property
def video(self) -> bool:
"""Returns True if this dataset loads video frames from mp4 files.
Returns False if it only loads images from png files.
"""
return self.info.get("video", False)
def keys(self) -> list[str]:
"""Keys to access non-image data (state, actions etc.)."""
return self.info["keys"]
@property
def features(self) -> datasets.Features:
return self.hf_dataset.features
def image_keys(self) -> list[str]:
"""Keys to access visual modalities stored as images."""
return self.info["image_keys"]
@property
def video_keys(self) -> list[str]:
"""Keys to access visual modalities stored as videos."""
return self.info["video_keys"]
@property
def camera_keys(self) -> list[str]:
"""Keys to access image and video stream from cameras."""
keys = []
for key, feats in self.hf_dataset.features.items():
if isinstance(feats, (datasets.Image, VideoFrame)):
keys.append(key)
return keys
"""Keys to access image and video streams from cameras."""
return self.image_keys + self.video_keys
@property
def video_frame_keys(self) -> list[str]:
@@ -117,8 +158,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
@property
def num_episodes(self) -> int:
"""Number of episodes."""
return len(self.hf_dataset.unique("episode_index"))
"""Number of episodes selected."""
return len(self.episodes) if self.episodes is not None else self.total_episodes
@property
def total_episodes(self) -> int:
"""Total number of episodes available."""
return self.info["total_episodes"]
@property
def tolerance_s(self) -> float:
@@ -129,6 +175,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
# 1e-4 to account for possible numerical error
return 1 / self.fps - 1e-4
@property
def shapes(self) -> dict:
"""Shapes for the different features."""
self.info.get("shapes")
def get_episode_data_index(self) -> dict[str, torch.Tensor]:
episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(self.episode_dicts)}
if self.episodes is not None:
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in self.episodes}
cumulative_lenghts = list(accumulate(episode_lengths.values()))
return {
"from": torch.LongTensor([0] + cumulative_lenghts[:-1]),
"to": torch.LongTensor(cumulative_lenghts),
}
def __len__(self):
return self.num_samples
@@ -147,7 +209,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
if self.video:
item = load_from_videos(
item,
self.video_frame_keys,
self.video_keys,
self.videos_dir,
self.tolerance_s,
self.video_backend,
@@ -225,7 +287,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
def __init__(
self,
repo_ids: list[str],
root: Path | None = DATA_DIR,
root: Path | None = LEROBOT_HOME,
split: str = "train",
image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None,