Rework LeRobotDataset.__init__
@@ -15,6 +15,7 @@
 # limitations under the License.
 import logging
 import os
+from itertools import accumulate
 from pathlib import Path
 from typing import Callable
 
@@ -24,27 +25,27 @@ import torch.utils
 
 from lerobot.common.datasets.compute_stats import aggregate_stats
 from lerobot.common.datasets.utils import (
-    calculate_episode_data_index,
-    load_episode_data_index,
+    download_episodes,
+    get_hub_safe_version,
     load_hf_dataset,
     load_info,
     load_previous_and_future_frames,
     load_stats,
-    load_videos,
-    reset_episode_index,
+    load_tasks,
 )
 from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos
 
 # For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
-CODEBASE_VERSION = "v1.6"
-DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
+CODEBASE_VERSION = "v2.0"
+LEROBOT_HOME = Path(os.getenv("LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser()
 
 
 class LeRobotDataset(torch.utils.data.Dataset):
     def __init__(
         self,
         repo_id: str,
-        root: Path | None = DATA_DIR,
+        root: Path | None = None,
+        episodes: list[int] | None = None,
         split: str = "train",
         image_transforms: Callable | None = None,
         delta_timestamps: dict[list[float]] | None = None,
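Note on usage: `root` no longer defaults to a `DATA_DIR` environment variable; when it is omitted, the dataset is cached under `LEROBOT_HOME / repo_id`, and `LEROBOT_HOME` itself can be overridden through the environment. A minimal sketch of how the reworked constructor would be called, assuming the class is importable from lerobot.common.datasets.lerobot_dataset as in the rest of the codebase (the repo id and episode indices are placeholders):

    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

    # No root given: data is downloaded and cached under LEROBOT_HOME / "lerobot/pusht".
    dataset = LeRobotDataset("lerobot/pusht")

    # The new `episodes` argument restricts loading to a subset of episode indices.
    subset = LeRobotDataset("lerobot/pusht", episodes=[0, 4, 5])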
@@ -52,24 +53,64 @@ class LeRobotDataset(torch.utils.data.Dataset):
     ):
         super().__init__()
         self.repo_id = repo_id
-        self.root = root
+        self.root = root if root is not None else LEROBOT_HOME / repo_id
         self.split = split
         self.image_transforms = image_transforms
         self.delta_timestamps = delta_timestamps
-        # load data from hub or locally when root is provided
-        # TODO(rcadene, aliberts): implement faster transfer
-        # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
-        self.hf_dataset = load_hf_dataset(repo_id, CODEBASE_VERSION, root, split)
-        if split == "train":
-            self.episode_data_index = load_episode_data_index(repo_id, CODEBASE_VERSION, root)
-        else:
-            self.episode_data_index = calculate_episode_data_index(self.hf_dataset)
-            self.hf_dataset = reset_episode_index(self.hf_dataset)
-        self.stats = load_stats(repo_id, CODEBASE_VERSION, root)
-        self.info = load_info(repo_id, CODEBASE_VERSION, root)
-        if self.video:
-            self.videos_dir = load_videos(repo_id, CODEBASE_VERSION, root)
-            self.video_backend = video_backend if video_backend is not None else "pyav"
+        self.episodes = episodes
+        self.video_backend = video_backend if video_backend is not None else "pyav"
+
+        # Load metadata
+        self.root.mkdir(exist_ok=True, parents=True)
+        self._version = get_hub_safe_version(repo_id, CODEBASE_VERSION)
+        self.info = load_info(repo_id, self._version, self.root)
+        self.stats = load_stats(repo_id, self._version, self.root)
+        self.tasks = load_tasks(repo_id, self._version, self.root)
+
+        # Load actual data
+        download_episodes(
+            repo_id,
+            self._version,
+            self.root,
+            self.data_path,
+            self.video_keys,
+            self.num_episodes,
+            self.episodes,
+            self.videos_path,
+        )
+        self.hf_dataset = load_hf_dataset(self.root, self.data_path, self.total_episodes, self.episodes)
+        self.episode_data_index = self.get_episode_data_index()
+
+        # TODO(aliberts):
+        # - [ ] Update __getitem__
+        # - [ ] Add self.consolidate() for:
+        #   - [ ] Sanity checks (episodes num, shapes, files, etc.)
+        #   - [ ] Update episode_index (arg update=True)
+        #   - [ ] Update info.json (arg update=True)
+
+        # TODO(aliberts): remove (deprecated)
+        # if split == "train":
+        #     self.episode_data_index = load_episode_data_index(self.episodes, self.episode_list)
+        # else:
+        #     self.episode_data_index = calculate_episode_data_index(self.hf_dataset)
+        #     self.hf_dataset = reset_episode_index(self.hf_dataset)
+        # if self.video:
+        #     self.videos_dir = load_videos(repo_id, CODEBASE_VERSION, root)
+
+    @property
+    def data_path(self) -> str:
+        """Formattable string for the parquet files."""
+        return self.info["data_path"]
+
+    @property
+    def videos_path(self) -> str | None:
+        """Formattable string for the video files."""
+        return self.info["videos"]["videos_path"] if len(self.video_keys) > 0 else None
+
+    @property
+    def episode_dicts(self) -> list[dict]:
+        """List of dictionaries containing information for each episode, indexed by episode_index."""
+        return self.info["episodes"]
 
     @property
     def fps(self) -> int:
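The constructor now loads metadata first (info, stats, tasks) and only then downloads the selected episodes, locating files through the formattable data_path / videos_path strings stored in the dataset's info. A sketch of how such a template resolves to a concrete file, using a purely hypothetical pattern (the real template is whatever info["data_path"] contains):

    # Hypothetical template; in practice this string comes from info.json via `data_path`.
    data_path = "data/train-episode_{episode_index:06d}.parquet"
    print(data_path.format(episode_index=3))  # -> data/train-episode_000003.parquet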
@@ -77,24 +118,24 @@ class LeRobotDataset(torch.utils.data.Dataset):
         return self.info["fps"]
 
     @property
-    def video(self) -> bool:
-        """Returns True if this dataset loads video frames from mp4 files.
-        Returns False if it only loads images from png files.
-        """
-        return self.info.get("video", False)
+    def keys(self) -> list[str]:
+        """Keys to access non-image data (state, actions etc.)."""
+        return self.info["keys"]
 
     @property
-    def features(self) -> datasets.Features:
-        return self.hf_dataset.features
+    def image_keys(self) -> list[str]:
+        """Keys to access visual modalities stored as images."""
+        return self.info["image_keys"]
+
+    @property
+    def video_keys(self) -> list[str]:
+        """Keys to access visual modalities stored as videos."""
+        return self.info["video_keys"]
 
     @property
     def camera_keys(self) -> list[str]:
-        """Keys to access image and video stream from cameras."""
-        keys = []
-        for key, feats in self.hf_dataset.features.items():
-            if isinstance(feats, (datasets.Image, VideoFrame)):
-                keys.append(key)
-        return keys
+        """Keys to access image and video streams from cameras."""
+        return self.image_keys + self.video_keys
 
     @property
     def video_frame_keys(self) -> list[str]:
@@ -117,8 +158,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
 
     @property
     def num_episodes(self) -> int:
-        """Number of episodes."""
-        return len(self.hf_dataset.unique("episode_index"))
+        """Number of episodes selected."""
+        return len(self.episodes) if self.episodes is not None else self.total_episodes
+
+    @property
+    def total_episodes(self) -> int:
+        """Total number of episodes available."""
+        return self.info["total_episodes"]
 
     @property
     def tolerance_s(self) -> float:
@@ -129,6 +175,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
         # 1e-4 to account for possible numerical error
         return 1 / self.fps - 1e-4
 
+    @property
+    def shapes(self) -> dict:
+        """Shapes for the different features."""
+        return self.info.get("shapes")
+
+    def get_episode_data_index(self) -> dict[str, torch.Tensor]:
+        episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(self.episode_dicts)}
+        if self.episodes is not None:
+            episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in self.episodes}
+
+        cumulative_lengths = list(accumulate(episode_lengths.values()))
+        return {
+            "from": torch.LongTensor([0] + cumulative_lengths[:-1]),
+            "to": torch.LongTensor(cumulative_lengths),
+        }
+
     def __len__(self):
         return self.num_samples
 
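get_episode_data_index turns per-episode lengths into global frame index ranges via itertools.accumulate. A small worked example with made-up lengths for three episodes:

    from itertools import accumulate

    import torch

    episode_lengths = {0: 10, 1: 5, 2: 8}  # made-up frame counts per episode
    cumulative = list(accumulate(episode_lengths.values()))  # [10, 15, 23]
    episode_data_index = {
        "from": torch.LongTensor([0] + cumulative[:-1]),  # tensor([ 0, 10, 15])
        "to": torch.LongTensor(cumulative),                # tensor([10, 15, 23])
    }
    # Frames of episode 1 span rows 10 (inclusive) to 15 (exclusive) in the flattened dataset.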
@@ -147,7 +209,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         if self.video:
             item = load_from_videos(
                 item,
-                self.video_frame_keys,
+                self.video_keys,
                 self.videos_dir,
                 self.tolerance_s,
                 self.video_backend,
@@ -225,7 +287,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
     def __init__(
         self,
         repo_ids: list[str],
-        root: Path | None = DATA_DIR,
+        root: Path | None = LEROBOT_HOME,
         split: str = "train",
         image_transforms: Callable | None = None,
         delta_timestamps: dict[list[float]] | None = None,