|
|
|
|
@@ -45,7 +45,7 @@ from lerobot.common.datasets.utils import (
|
|
|
|
|
get_episode_data_index,
|
|
|
|
|
get_hub_safe_version,
|
|
|
|
|
hf_transform_to_torch,
|
|
|
|
|
load_episode_dicts,
|
|
|
|
|
load_episodes,
|
|
|
|
|
load_info,
|
|
|
|
|
load_stats,
|
|
|
|
|
load_tasks,
|
|
|
|
|
@@ -66,6 +66,237 @@ CODEBASE_VERSION = "v2.0"
|
|
|
|
|
LEROBOT_HOME = Path(os.getenv("LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LeRobotDatasetMetadata:
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
repo_id: str,
|
|
|
|
|
root: Path | None = None,
|
|
|
|
|
local_files_only: bool = False,
|
|
|
|
|
):
|
|
|
|
|
self.repo_id = repo_id
|
|
|
|
|
self.root = root if root is not None else LEROBOT_HOME / repo_id
|
|
|
|
|
self.local_files_only = local_files_only
|
|
|
|
|
|
|
|
|
|
# Load metadata
|
|
|
|
|
(self.root / "meta").mkdir(exist_ok=True, parents=True)
|
|
|
|
|
self.pull_from_repo(allow_patterns="meta/")
|
|
|
|
|
self.info = load_info(self.root)
|
|
|
|
|
self.stats = load_stats(self.root)
|
|
|
|
|
self.tasks = load_tasks(self.root)
|
|
|
|
|
self.episodes = load_episodes(self.root)
|
|
|
|
|
|
|
|
|
|
def pull_from_repo(
|
|
|
|
|
self,
|
|
|
|
|
allow_patterns: list[str] | str | None = None,
|
|
|
|
|
ignore_patterns: list[str] | str | None = None,
|
|
|
|
|
) -> None:
|
|
|
|
|
snapshot_download(
|
|
|
|
|
self.repo_id,
|
|
|
|
|
repo_type="dataset",
|
|
|
|
|
revision=self._hub_version,
|
|
|
|
|
local_dir=self.root,
|
|
|
|
|
allow_patterns=allow_patterns,
|
|
|
|
|
ignore_patterns=ignore_patterns,
|
|
|
|
|
local_files_only=self.local_files_only,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@cached_property
|
|
|
|
|
def _hub_version(self) -> str | None:
|
|
|
|
|
return None if self.local_files_only else get_hub_safe_version(self.repo_id, CODEBASE_VERSION)
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def _version(self) -> str:
|
|
|
|
|
"""Codebase version used to create this dataset."""
|
|
|
|
|
return self.info["codebase_version"]
|
|
|
|
|
|
|
|
|
|
def get_data_file_path(self, ep_index: int) -> Path:
|
|
|
|
|
ep_chunk = self.get_episode_chunk(ep_index)
|
|
|
|
|
fpath = self.data_path.format(episode_chunk=ep_chunk, episode_index=ep_index)
|
|
|
|
|
return Path(fpath)
|
|
|
|
|
|
|
|
|
|
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
|
|
|
|
|
ep_chunk = self.get_episode_chunk(ep_index)
|
|
|
|
|
fpath = self.videos_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
|
|
|
|
|
return Path(fpath)
|
|
|
|
|
|
|
|
|
|
def get_episode_chunk(self, ep_index: int) -> int:
|
|
|
|
|
return ep_index // self.chunks_size
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def data_path(self) -> str:
|
|
|
|
|
"""Formattable string for the parquet files."""
|
|
|
|
|
return self.info["data_path"]
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def videos_path(self) -> str | None:
|
|
|
|
|
"""Formattable string for the video files."""
|
|
|
|
|
return self.info["videos"]["videos_path"] if len(self.video_keys) > 0 else None
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def fps(self) -> int:
|
|
|
|
|
"""Frames per second used during data collection."""
|
|
|
|
|
return self.info["fps"]
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def keys(self) -> list[str]:
|
|
|
|
|
"""Keys to access non-image data (state, actions etc.)."""
|
|
|
|
|
return self.info["keys"]
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def image_keys(self) -> list[str]:
|
|
|
|
|
"""Keys to access visual modalities stored as images."""
|
|
|
|
|
return self.info["image_keys"]
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def video_keys(self) -> list[str]:
|
|
|
|
|
"""Keys to access visual modalities stored as videos."""
|
|
|
|
|
return self.info["video_keys"]
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def camera_keys(self) -> list[str]:
|
|
|
|
|
"""Keys to access visual modalities (regardless of their storage method)."""
|
|
|
|
|
return self.image_keys + self.video_keys
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def names(self) -> dict[list[str]]:
|
|
|
|
|
"""Names of the various dimensions of vector modalities."""
|
|
|
|
|
return self.info["names"]
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def total_episodes(self) -> int:
|
|
|
|
|
"""Total number of episodes available."""
|
|
|
|
|
return self.info["total_episodes"]
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def total_frames(self) -> int:
|
|
|
|
|
"""Total number of frames saved in this dataset."""
|
|
|
|
|
return self.info["total_frames"]
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def total_tasks(self) -> int:
|
|
|
|
|
"""Total number of different tasks performed in this dataset."""
|
|
|
|
|
return self.info["total_tasks"]
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def total_chunks(self) -> int:
|
|
|
|
|
"""Total number of chunks (groups of episodes)."""
|
|
|
|
|
return self.info["total_chunks"]
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def chunks_size(self) -> int:
|
|
|
|
|
"""Max number of episodes per chunk."""
|
|
|
|
|
return self.info["chunks_size"]
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def shapes(self) -> dict:
|
|
|
|
|
"""Shapes for the different features."""
|
|
|
|
|
return self.info["shapes"]
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def task_to_task_index(self) -> dict:
|
|
|
|
|
return {task: task_idx for task_idx, task in self.tasks.items()}
|
|
|
|
|
|
|
|
|
|
def get_task_index(self, task: str) -> int:
|
|
|
|
|
"""
|
|
|
|
|
Given a task in natural language, returns its task_index if the task already exists in the dataset,
|
|
|
|
|
otherwise creates a new task_index.
|
|
|
|
|
"""
|
|
|
|
|
task_index = self.task_to_task_index.get(task, None)
|
|
|
|
|
return task_index if task_index is not None else self.total_tasks
|
|
|
|
|
|
|
|
|
|
def add_episode(self, episode_index: int, episode_length: int, task: str, task_index: int) -> None:
|
|
|
|
|
self.info["total_episodes"] += 1
|
|
|
|
|
self.info["total_frames"] += episode_length
|
|
|
|
|
|
|
|
|
|
if task_index not in self.tasks:
|
|
|
|
|
self.info["total_tasks"] += 1
|
|
|
|
|
self.tasks[task_index] = task
|
|
|
|
|
task_dict = {
|
|
|
|
|
"task_index": task_index,
|
|
|
|
|
"task": task,
|
|
|
|
|
}
|
|
|
|
|
append_jsonlines(task_dict, self.root / TASKS_PATH)
|
|
|
|
|
|
|
|
|
|
chunk = self.get_episode_chunk(episode_index)
|
|
|
|
|
if chunk >= self.total_chunks:
|
|
|
|
|
self.info["total_chunks"] += 1
|
|
|
|
|
|
|
|
|
|
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
|
|
|
|
|
self.info["total_videos"] += len(self.video_keys)
|
|
|
|
|
write_json(self.info, self.root / INFO_PATH)
|
|
|
|
|
|
|
|
|
|
episode_dict = {
|
|
|
|
|
"episode_index": episode_index,
|
|
|
|
|
"tasks": [task],
|
|
|
|
|
"length": episode_length,
|
|
|
|
|
}
|
|
|
|
|
self.episodes.append(episode_dict)
|
|
|
|
|
append_jsonlines(episode_dict, self.root / EPISODES_PATH)
|
|
|
|
|
|
|
|
|
|
def write_video_info(self) -> None:
|
|
|
|
|
"""
|
|
|
|
|
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
|
|
|
|
|
been encoded the same way. Also, this means it assumes the first episode exists.
|
|
|
|
|
"""
|
|
|
|
|
for key in self.video_keys:
|
|
|
|
|
if key not in self.info["videos"]:
|
|
|
|
|
video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
|
|
|
|
|
self.info["videos"][key] = get_video_info(video_path)
|
|
|
|
|
|
|
|
|
|
write_json(self.info, self.root / INFO_PATH)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def create(
|
|
|
|
|
cls,
|
|
|
|
|
repo_id: str,
|
|
|
|
|
fps: int,
|
|
|
|
|
root: Path | None = None,
|
|
|
|
|
robot: Robot | None = None,
|
|
|
|
|
robot_type: str | None = None,
|
|
|
|
|
keys: list[str] | None = None,
|
|
|
|
|
image_keys: list[str] | None = None,
|
|
|
|
|
video_keys: list[str] = None,
|
|
|
|
|
shapes: dict | None = None,
|
|
|
|
|
names: dict | None = None,
|
|
|
|
|
use_videos: bool = True,
|
|
|
|
|
) -> "LeRobotDatasetMetadata":
|
|
|
|
|
"""Creates metadata for a LeRobotDataset."""
|
|
|
|
|
obj = cls.__new__(cls)
|
|
|
|
|
obj.repo_id = repo_id
|
|
|
|
|
obj.root = root if root is not None else LEROBOT_HOME / repo_id
|
|
|
|
|
obj.image_writer = None
|
|
|
|
|
|
|
|
|
|
if robot is not None:
|
|
|
|
|
robot_type, keys, image_keys, video_keys, shapes, names = _get_info_from_robot(robot, use_videos)
|
|
|
|
|
if not all(cam.fps == fps for cam in robot.cameras.values()):
|
|
|
|
|
logging.warning(
|
|
|
|
|
f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
|
|
|
|
|
"In this case, frames from lower fps cameras will be repeated to fill in the blanks"
|
|
|
|
|
)
|
|
|
|
|
elif (
|
|
|
|
|
robot_type is None
|
|
|
|
|
or keys is None
|
|
|
|
|
or image_keys is None
|
|
|
|
|
or video_keys is None
|
|
|
|
|
or shapes is None
|
|
|
|
|
or names is None
|
|
|
|
|
):
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"Dataset info (robot_type, keys, shapes...) must either come from a Robot or explicitly passed upon creation."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if len(video_keys) > 0 and not use_videos:
|
|
|
|
|
raise ValueError()
|
|
|
|
|
|
|
|
|
|
obj.tasks, obj.stats, obj.episodes = {}, {}, []
|
|
|
|
|
obj.info = create_empty_dataset_info(
|
|
|
|
|
CODEBASE_VERSION, fps, robot_type, keys, image_keys, video_keys, shapes, names
|
|
|
|
|
)
|
|
|
|
|
write_json(obj.info, obj.root / INFO_PATH)
|
|
|
|
|
obj.local_files_only = True
|
|
|
|
|
return obj
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LeRobotDataset(torch.utils.data.Dataset):
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
@@ -86,9 +317,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|
|
|
|
- On your local disk in the 'root' folder. This is typically the case when you recorded your
|
|
|
|
|
dataset locally and you may or may not have pushed it to the hub yet. Instantiating this class
|
|
|
|
|
with 'root' will load your dataset directly from disk. This can happen while you're offline (no
|
|
|
|
|
internet connection).
|
|
|
|
|
internet connection), in that case, use local_files_only=True.
|
|
|
|
|
|
|
|
|
|
- On the Hugging Face Hub at the address https://huggingface.co/datasets/{repo_id} and is not on
|
|
|
|
|
- On the Hugging Face Hub at the address https://huggingface.co/datasets/{repo_id} and not on
|
|
|
|
|
your local disk in the 'root' folder. Instantiating this class with this 'repo_id' will download
|
|
|
|
|
the dataset from that address and load it, pending your dataset is compliant with
|
|
|
|
|
codebase_version v2.0. If your dataset has been created before this new format, you will be
|
|
|
|
|
@@ -96,9 +327,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|
|
|
|
lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2. Your dataset doesn't already exists (either on local disk or on the Hub):
|
|
|
|
|
You can create an empty LeRobotDataset with the 'create' classmethod. This can be used for
|
|
|
|
|
recording a dataset or port an existing dataset to the LeRobotDataset format.
|
|
|
|
|
2. Your dataset doesn't already exists (either on local disk or on the Hub): you can create an empty
|
|
|
|
|
LeRobotDataset with the 'create' classmethod. This can be used for recording a dataset or port an
|
|
|
|
|
existing dataset to the LeRobotDataset format.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
In terms of files, LeRobotDataset encapsulates 3 main things:
|
|
|
|
|
@@ -192,21 +423,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|
|
|
|
self.image_writer = None
|
|
|
|
|
self.episode_buffer = {}
|
|
|
|
|
|
|
|
|
|
# Load metadata
|
|
|
|
|
self.root.mkdir(exist_ok=True, parents=True)
|
|
|
|
|
self.pull_from_repo(allow_patterns="meta/")
|
|
|
|
|
self.info = load_info(self.root)
|
|
|
|
|
self.stats = load_stats(self.root)
|
|
|
|
|
self.tasks = load_tasks(self.root)
|
|
|
|
|
self.episode_dicts = load_episode_dicts(self.root)
|
|
|
|
|
|
|
|
|
|
# Load metadata
|
|
|
|
|
self.meta = LeRobotDatasetMetadata(self.repo_id, self.root, self.local_files_only)
|
|
|
|
|
|
|
|
|
|
# Check version
|
|
|
|
|
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
|
|
|
|
check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
|
|
|
|
|
|
|
|
|
|
# Load actual data
|
|
|
|
|
self.download_episodes(download_videos)
|
|
|
|
|
self.hf_dataset = self.load_hf_dataset()
|
|
|
|
|
self.episode_data_index = get_episode_data_index(self.episode_dicts, self.episodes)
|
|
|
|
|
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
|
|
|
|
|
|
|
|
|
|
# Check timestamps
|
|
|
|
|
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
|
|
|
|
|
@@ -216,26 +444,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|
|
|
|
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
|
|
|
|
|
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
|
|
|
|
|
|
|
|
|
|
# TODO(aliberts):
|
|
|
|
|
# - [X] Move delta_timestamp logic outside __get_item__
|
|
|
|
|
# - [X] Update __get_item__
|
|
|
|
|
# - [/] Add doc
|
|
|
|
|
# - [ ] Add self.add_frame()
|
|
|
|
|
# - [ ] Add self.consolidate() for:
|
|
|
|
|
# - [X] Check timestamps sync
|
|
|
|
|
# - [ ] Sanity checks (episodes num, shapes, files, etc.)
|
|
|
|
|
# - [ ] Update episode_index (arg update=True)
|
|
|
|
|
# - [ ] Update info.json (arg update=True)
|
|
|
|
|
|
|
|
|
|
@cached_property
|
|
|
|
|
def _hub_version(self) -> str | None:
|
|
|
|
|
return None if self.local_files_only else get_hub_safe_version(self.repo_id, CODEBASE_VERSION)
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def _version(self) -> str:
|
|
|
|
|
"""Codebase version used to create this dataset."""
|
|
|
|
|
return self.info["codebase_version"]
|
|
|
|
|
|
|
|
|
|
def push_to_hub(self, push_videos: bool = True) -> None:
|
|
|
|
|
if not self.consolidated:
|
|
|
|
|
raise RuntimeError(
|
|
|
|
|
@@ -262,7 +470,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|
|
|
|
snapshot_download(
|
|
|
|
|
self.repo_id,
|
|
|
|
|
repo_type="dataset",
|
|
|
|
|
revision=self._hub_version,
|
|
|
|
|
revision=self.meta._hub_version,
|
|
|
|
|
local_dir=self.root,
|
|
|
|
|
allow_patterns=allow_patterns,
|
|
|
|
|
ignore_patterns=ignore_patterns,
|
|
|
|
|
@@ -280,11 +488,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|
|
|
|
files = None
|
|
|
|
|
ignore_patterns = None if download_videos else "videos/"
|
|
|
|
|
if self.episodes is not None:
|
|
|
|
|
files = [str(self.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
|
|
|
|
|
if len(self.video_keys) > 0 and download_videos:
|
|
|
|
|
files = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
|
|
|
|
|
if len(self.meta.video_keys) > 0 and download_videos:
|
|
|
|
|
video_files = [
|
|
|
|
|
str(self.get_video_file_path(ep_idx, vid_key))
|
|
|
|
|
for vid_key in self.video_keys
|
|
|
|
|
str(self.meta.get_video_file_path(ep_idx, vid_key))
|
|
|
|
|
for vid_key in self.meta.video_keys
|
|
|
|
|
for ep_idx in self.episodes
|
|
|
|
|
]
|
|
|
|
|
files += video_files
|
|
|
|
|
@@ -297,108 +505,30 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|
|
|
|
path = str(self.root / "data")
|
|
|
|
|
hf_dataset = load_dataset("parquet", data_dir=path, split="train")
|
|
|
|
|
else:
|
|
|
|
|
files = [str(self.root / self.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
|
|
|
|
|
files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
|
|
|
|
|
hf_dataset = load_dataset("parquet", data_files=files, split="train")
|
|
|
|
|
|
|
|
|
|
hf_dataset.set_transform(hf_transform_to_torch)
|
|
|
|
|
return hf_dataset
|
|
|
|
|
|
|
|
|
|
def get_data_file_path(self, ep_index: int) -> Path:
|
|
|
|
|
ep_chunk = self.get_episode_chunk(ep_index)
|
|
|
|
|
fpath = self.data_path.format(episode_chunk=ep_chunk, episode_index=ep_index)
|
|
|
|
|
return Path(fpath)
|
|
|
|
|
|
|
|
|
|
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
|
|
|
|
|
ep_chunk = self.get_episode_chunk(ep_index)
|
|
|
|
|
fpath = self.videos_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
|
|
|
|
|
return Path(fpath)
|
|
|
|
|
|
|
|
|
|
def get_episode_chunk(self, ep_index: int) -> int:
|
|
|
|
|
return ep_index // self.chunks_size
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def data_path(self) -> str:
|
|
|
|
|
"""Formattable string for the parquet files."""
|
|
|
|
|
return self.info["data_path"]
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def videos_path(self) -> str | None:
|
|
|
|
|
"""Formattable string for the video files."""
|
|
|
|
|
return self.info["videos"]["videos_path"] if len(self.video_keys) > 0 else None
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def fps(self) -> int:
|
|
|
|
|
"""Frames per second used during data collection."""
|
|
|
|
|
return self.info["fps"]
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def keys(self) -> list[str]:
|
|
|
|
|
"""Keys to access non-image data (state, actions etc.)."""
|
|
|
|
|
return self.info["keys"]
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def image_keys(self) -> list[str]:
|
|
|
|
|
"""Keys to access visual modalities stored as images."""
|
|
|
|
|
return self.info["image_keys"]
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def video_keys(self) -> list[str]:
|
|
|
|
|
"""Keys to access visual modalities stored as videos."""
|
|
|
|
|
return self.info["video_keys"]
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def camera_keys(self) -> list[str]:
|
|
|
|
|
"""Keys to access visual modalities (regardless of their storage method)."""
|
|
|
|
|
return self.image_keys + self.video_keys
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def names(self) -> dict[list[str]]:
|
|
|
|
|
"""Names of the various dimensions of vector modalities."""
|
|
|
|
|
return self.info["names"]
|
|
|
|
|
return self.meta.fps
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def num_frames(self) -> int:
|
|
|
|
|
"""Number of frames in selected episodes."""
|
|
|
|
|
return len(self.hf_dataset) if self.hf_dataset is not None else self.total_frames
|
|
|
|
|
return len(self.hf_dataset) if self.hf_dataset is not None else self.meta.total_frames
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def num_episodes(self) -> int:
|
|
|
|
|
"""Number of episodes selected."""
|
|
|
|
|
return len(self.episodes) if self.episodes is not None else self.total_episodes
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def total_episodes(self) -> int:
|
|
|
|
|
"""Total number of episodes available."""
|
|
|
|
|
return self.info["total_episodes"]
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def total_frames(self) -> int:
|
|
|
|
|
"""Total number of frames saved in this dataset."""
|
|
|
|
|
return self.info["total_frames"]
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def total_tasks(self) -> int:
|
|
|
|
|
"""Total number of different tasks performed in this dataset."""
|
|
|
|
|
return self.info["total_tasks"]
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def total_chunks(self) -> int:
|
|
|
|
|
"""Total number of chunks (groups of episodes)."""
|
|
|
|
|
return self.info["total_chunks"]
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def chunks_size(self) -> int:
|
|
|
|
|
"""Max number of episodes per chunk."""
|
|
|
|
|
return self.info["chunks_size"]
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def shapes(self) -> dict:
|
|
|
|
|
"""Shapes for the different features."""
|
|
|
|
|
return self.info["shapes"]
|
|
|
|
|
return len(self.episodes) if self.episodes is not None else self.meta.total_episodes
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def features(self) -> list[str]:
|
|
|
|
|
return list(self._features) + self.video_keys
|
|
|
|
|
return list(self._features) + self.meta.video_keys
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def _features(self) -> datasets.Features:
|
|
|
|
|
@@ -418,39 +548,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|
|
|
|
features[key] = datasets.Value(dtype="bool")
|
|
|
|
|
elif key in ["timestamp", "next.reward"]:
|
|
|
|
|
features[key] = datasets.Value(dtype="float32")
|
|
|
|
|
elif key in self.image_keys:
|
|
|
|
|
elif key in self.meta.image_keys:
|
|
|
|
|
features[key] = datasets.Image()
|
|
|
|
|
elif key in self.keys:
|
|
|
|
|
elif key in self.meta.keys:
|
|
|
|
|
features[key] = datasets.Sequence(
|
|
|
|
|
length=self.shapes[key], feature=datasets.Value(dtype="float32")
|
|
|
|
|
length=self.meta.shapes[key], feature=datasets.Value(dtype="float32")
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return datasets.Features(features)
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def task_to_task_index(self) -> dict:
|
|
|
|
|
return {task: task_idx for task_idx, task in self.tasks.items()}
|
|
|
|
|
|
|
|
|
|
def get_task_index(self, task: str) -> int:
|
|
|
|
|
"""
|
|
|
|
|
Given a task in natural language, returns its task_index if the task already exists in the dataset,
|
|
|
|
|
otherwise creates a new task_index.
|
|
|
|
|
"""
|
|
|
|
|
task_index = self.task_to_task_index.get(task, None)
|
|
|
|
|
return task_index if task_index is not None else self.total_tasks
|
|
|
|
|
|
|
|
|
|
def current_episode_index(self, idx: int) -> int:
|
|
|
|
|
episode_index = self.hf_dataset["episode_index"][idx]
|
|
|
|
|
if self.episodes is not None:
|
|
|
|
|
# get episode_index from selected episodes
|
|
|
|
|
episode_index = self.episodes.index(episode_index)
|
|
|
|
|
|
|
|
|
|
return episode_index
|
|
|
|
|
|
|
|
|
|
def episode_length(self, episode_index) -> int:
|
|
|
|
|
"""Number of samples/frames for given episode."""
|
|
|
|
|
return self.info["episodes"][episode_index]["length"]
|
|
|
|
|
|
|
|
|
|
def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
|
|
|
|
|
ep_start = self.episode_data_index["from"][ep_idx]
|
|
|
|
|
ep_end = self.episode_data_index["to"][ep_idx]
|
|
|
|
|
@@ -472,7 +578,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|
|
|
|
query_indices: dict[str, list[int]] | None = None,
|
|
|
|
|
) -> dict[str, list[float]]:
|
|
|
|
|
query_timestamps = {}
|
|
|
|
|
for key in self.video_keys:
|
|
|
|
|
for key in self.meta.video_keys:
|
|
|
|
|
if query_indices is not None and key in query_indices:
|
|
|
|
|
timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
|
|
|
|
|
query_timestamps[key] = torch.stack(timestamps).tolist()
|
|
|
|
|
@@ -485,7 +591,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|
|
|
|
return {
|
|
|
|
|
key: torch.stack(self.hf_dataset.select(q_idx)[key])
|
|
|
|
|
for key, q_idx in query_indices.items()
|
|
|
|
|
if key not in self.video_keys
|
|
|
|
|
if key not in self.meta.video_keys
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict:
|
|
|
|
|
@@ -496,7 +602,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|
|
|
|
"""
|
|
|
|
|
item = {}
|
|
|
|
|
for vid_key, query_ts in query_timestamps.items():
|
|
|
|
|
video_path = self.root / self.get_video_file_path(ep_idx, vid_key)
|
|
|
|
|
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
|
|
|
|
|
frames = decode_video_frames_torchvision(
|
|
|
|
|
video_path, query_ts, self.tolerance_s, self.video_backend
|
|
|
|
|
)
|
|
|
|
|
@@ -525,14 +631,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|
|
|
|
for key, val in query_result.items():
|
|
|
|
|
item[key] = val
|
|
|
|
|
|
|
|
|
|
if len(self.video_keys) > 0:
|
|
|
|
|
if len(self.meta.video_keys) > 0:
|
|
|
|
|
current_ts = item["timestamp"].item()
|
|
|
|
|
query_timestamps = self._get_query_timestamps(current_ts, query_indices)
|
|
|
|
|
video_frames = self._query_videos(query_timestamps, ep_idx)
|
|
|
|
|
item = {**video_frames, **item}
|
|
|
|
|
|
|
|
|
|
if self.image_transforms is not None:
|
|
|
|
|
image_keys = self.camera_keys
|
|
|
|
|
image_keys = self.meta.camera_keys
|
|
|
|
|
for cam in image_keys:
|
|
|
|
|
item[cam] = self.image_transforms(item[cam])
|
|
|
|
|
|
|
|
|
|
@@ -545,20 +651,20 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|
|
|
|
f" Selected episodes: {self.episodes},\n"
|
|
|
|
|
f" Number of selected episodes: {self.num_episodes},\n"
|
|
|
|
|
f" Number of selected samples: {self.num_frames},\n"
|
|
|
|
|
f"\n{json.dumps(self.info, indent=4)}\n"
|
|
|
|
|
f"\n{json.dumps(self.meta.info, indent=4)}\n"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def _create_episode_buffer(self, episode_index: int | None = None) -> dict:
|
|
|
|
|
# TODO(aliberts): Handle resume
|
|
|
|
|
return {
|
|
|
|
|
"size": 0,
|
|
|
|
|
"episode_index": self.total_episodes if episode_index is None else episode_index,
|
|
|
|
|
"episode_index": self.meta.total_episodes if episode_index is None else episode_index,
|
|
|
|
|
"task_index": None,
|
|
|
|
|
"frame_index": [],
|
|
|
|
|
"timestamp": [],
|
|
|
|
|
"next.done": [],
|
|
|
|
|
**{key: [] for key in self.keys},
|
|
|
|
|
**{key: [] for key in self.image_keys},
|
|
|
|
|
**{key: [] for key in self.meta.keys},
|
|
|
|
|
**{key: [] for key in self.meta.image_keys},
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
def add_frame(self, frame: dict) -> None:
|
|
|
|
|
@@ -573,7 +679,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|
|
|
|
self.episode_buffer["next.done"].append(False)
|
|
|
|
|
|
|
|
|
|
# Save all observed modalities except images
|
|
|
|
|
for key in self.keys:
|
|
|
|
|
for key in self.meta.keys:
|
|
|
|
|
self.episode_buffer[key].append(frame[key])
|
|
|
|
|
|
|
|
|
|
self.episode_buffer["size"] += 1
|
|
|
|
|
@@ -582,7 +688,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
# Save images
|
|
|
|
|
for cam_key in self.camera_keys:
|
|
|
|
|
for cam_key in self.meta.camera_keys:
|
|
|
|
|
img_path = self.image_writer.get_image_file_path(
|
|
|
|
|
episode_index=self.episode_buffer["episode_index"], image_key=cam_key, frame_index=frame_index
|
|
|
|
|
)
|
|
|
|
|
@@ -594,7 +700,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|
|
|
|
fpath=img_path,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if cam_key in self.image_keys:
|
|
|
|
|
if cam_key in self.meta.image_keys:
|
|
|
|
|
self.episode_buffer[cam_key].append(str(img_path))
|
|
|
|
|
|
|
|
|
|
def add_episode(self, task: str, encode_videos: bool = False) -> None:
|
|
|
|
|
@@ -609,17 +715,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|
|
|
|
"""
|
|
|
|
|
episode_length = self.episode_buffer.pop("size")
|
|
|
|
|
episode_index = self.episode_buffer["episode_index"]
|
|
|
|
|
if episode_index != self.total_episodes:
|
|
|
|
|
if episode_index != self.meta.total_episodes:
|
|
|
|
|
# TODO(aliberts): Add option to use existing episode_index
|
|
|
|
|
raise NotImplementedError()
|
|
|
|
|
|
|
|
|
|
task_index = self.get_task_index(task)
|
|
|
|
|
task_index = self.meta.get_task_index(task)
|
|
|
|
|
self.episode_buffer["next.done"][-1] = True
|
|
|
|
|
|
|
|
|
|
for key in self.episode_buffer:
|
|
|
|
|
if key in self.image_keys:
|
|
|
|
|
if key in self.meta.image_keys:
|
|
|
|
|
continue
|
|
|
|
|
elif key in self.keys:
|
|
|
|
|
elif key in self.meta.keys:
|
|
|
|
|
self.episode_buffer[key] = torch.stack(self.episode_buffer[key])
|
|
|
|
|
elif key == "episode_index":
|
|
|
|
|
self.episode_buffer[key] = torch.full((episode_length,), episode_index)
|
|
|
|
|
@@ -628,13 +734,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|
|
|
|
else:
|
|
|
|
|
self.episode_buffer[key] = torch.tensor(self.episode_buffer[key])
|
|
|
|
|
|
|
|
|
|
self.episode_buffer["index"] = torch.arange(self.total_frames, self.total_frames + episode_length)
|
|
|
|
|
self._save_episode_to_metadata(episode_index, episode_length, task, task_index)
|
|
|
|
|
self.episode_buffer["index"] = torch.arange(
|
|
|
|
|
self.meta.total_frames, self.meta.total_frames + episode_length
|
|
|
|
|
)
|
|
|
|
|
self.meta.add_episode(episode_index, episode_length, task, task_index)
|
|
|
|
|
|
|
|
|
|
self._wait_image_writer()
|
|
|
|
|
self._save_episode_table(episode_index)
|
|
|
|
|
|
|
|
|
|
if encode_videos and len(self.video_keys) > 0:
|
|
|
|
|
if encode_videos and len(self.meta.video_keys) > 0:
|
|
|
|
|
self.encode_videos()
|
|
|
|
|
|
|
|
|
|
# Reset the buffer
|
|
|
|
|
@@ -643,45 +751,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|
|
|
|
|
|
|
|
|
def _save_episode_table(self, episode_index: int) -> None:
|
|
|
|
|
ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self._features, split="train")
|
|
|
|
|
ep_data_path = self.root / self.get_data_file_path(ep_index=episode_index)
|
|
|
|
|
ep_data_path = self.root / self.meta.get_data_file_path(ep_index=episode_index)
|
|
|
|
|
ep_data_path.parent.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
write_parquet(ep_dataset, ep_data_path)
|
|
|
|
|
|
|
|
|
|
def _save_episode_to_metadata(
|
|
|
|
|
self, episode_index: int, episode_length: int, task: str, task_index: int
|
|
|
|
|
) -> None:
|
|
|
|
|
self.info["total_episodes"] += 1
|
|
|
|
|
self.info["total_frames"] += episode_length
|
|
|
|
|
|
|
|
|
|
if task_index not in self.tasks:
|
|
|
|
|
self.info["total_tasks"] += 1
|
|
|
|
|
self.tasks[task_index] = task
|
|
|
|
|
task_dict = {
|
|
|
|
|
"task_index": task_index,
|
|
|
|
|
"task": task,
|
|
|
|
|
}
|
|
|
|
|
append_jsonlines(task_dict, self.root / TASKS_PATH)
|
|
|
|
|
|
|
|
|
|
chunk = self.get_episode_chunk(episode_index)
|
|
|
|
|
if chunk >= self.total_chunks:
|
|
|
|
|
self.info["total_chunks"] += 1
|
|
|
|
|
|
|
|
|
|
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
|
|
|
|
|
self.info["total_videos"] += len(self.video_keys)
|
|
|
|
|
write_json(self.info, self.root / INFO_PATH)
|
|
|
|
|
|
|
|
|
|
episode_dict = {
|
|
|
|
|
"episode_index": episode_index,
|
|
|
|
|
"tasks": [task],
|
|
|
|
|
"length": episode_length,
|
|
|
|
|
}
|
|
|
|
|
self.episode_dicts.append(episode_dict)
|
|
|
|
|
append_jsonlines(episode_dict, self.root / EPISODES_PATH)
|
|
|
|
|
|
|
|
|
|
def clear_episode_buffer(self) -> None:
|
|
|
|
|
episode_index = self.episode_buffer["episode_index"]
|
|
|
|
|
if self.image_writer is not None:
|
|
|
|
|
for cam_key in self.camera_keys:
|
|
|
|
|
for cam_key in self.meta.camera_keys:
|
|
|
|
|
img_dir = self.image_writer.get_episode_dir(episode_index, cam_key)
|
|
|
|
|
if img_dir.is_dir():
|
|
|
|
|
shutil.rmtree(img_dir)
|
|
|
|
|
@@ -717,12 +794,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|
|
|
|
|
|
|
|
|
def encode_videos(self) -> None:
|
|
|
|
|
# Use ffmpeg to convert frames stored as png into mp4 videos
|
|
|
|
|
for episode_index in range(self.total_episodes):
|
|
|
|
|
for key in self.video_keys:
|
|
|
|
|
for episode_index in range(self.meta.total_episodes):
|
|
|
|
|
for key in self.meta.video_keys:
|
|
|
|
|
# TODO: create video_buffer to store the state of encoded/unencoded videos and remove the need
|
|
|
|
|
# to call self.image_writer here
|
|
|
|
|
tmp_imgs_dir = self.image_writer.get_episode_dir(episode_index, key)
|
|
|
|
|
video_path = self.root / self.get_video_file_path(episode_index, key)
|
|
|
|
|
video_path = self.root / self.meta.get_video_file_path(episode_index, key)
|
|
|
|
|
if video_path.is_file():
|
|
|
|
|
# Skip if video is already encoded. Could be the case when resuming data recording.
|
|
|
|
|
continue
|
|
|
|
|
@@ -730,40 +807,28 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|
|
|
|
# since video encoding with ffmpeg is already using multithreading.
|
|
|
|
|
encode_video_frames(tmp_imgs_dir, video_path, self.fps, overwrite=True)
|
|
|
|
|
|
|
|
|
|
def _write_video_info(self) -> None:
|
|
|
|
|
"""
|
|
|
|
|
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
|
|
|
|
|
been encoded the same way. Also, this means it assumes the first episode exists.
|
|
|
|
|
"""
|
|
|
|
|
for key in self.video_keys:
|
|
|
|
|
if key not in self.info["videos"]:
|
|
|
|
|
video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
|
|
|
|
|
self.info["videos"][key] = get_video_info(video_path)
|
|
|
|
|
|
|
|
|
|
write_json(self.info, self.root / INFO_PATH)
|
|
|
|
|
|
|
|
|
|
def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None:
|
|
|
|
|
self.hf_dataset = self.load_hf_dataset()
|
|
|
|
|
self.episode_data_index = get_episode_data_index(self.episode_dicts, self.episodes)
|
|
|
|
|
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
|
|
|
|
|
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
|
|
|
|
|
|
|
|
|
|
if len(self.video_keys) > 0:
|
|
|
|
|
if len(self.meta.video_keys) > 0:
|
|
|
|
|
self.encode_videos()
|
|
|
|
|
self._write_video_info()
|
|
|
|
|
self.meta.write_video_info()
|
|
|
|
|
|
|
|
|
|
if not keep_image_files and self.image_writer is not None:
|
|
|
|
|
shutil.rmtree(self.image_writer.write_dir)
|
|
|
|
|
|
|
|
|
|
video_files = list(self.root.rglob("*.mp4"))
|
|
|
|
|
assert len(video_files) == self.num_episodes * len(self.video_keys)
|
|
|
|
|
assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
|
|
|
|
|
|
|
|
|
|
parquet_files = list(self.root.rglob("*.parquet"))
|
|
|
|
|
assert len(parquet_files) == self.num_episodes
|
|
|
|
|
|
|
|
|
|
if run_compute_stats:
|
|
|
|
|
self.stop_image_writer()
|
|
|
|
|
self.stats = compute_stats(self)
|
|
|
|
|
write_stats(self.stats, self.root / STATS_PATH)
|
|
|
|
|
self.meta.stats = compute_stats(self)
|
|
|
|
|
write_stats(self.meta.stats, self.root / STATS_PATH)
|
|
|
|
|
self.consolidated = True
|
|
|
|
|
else:
|
|
|
|
|
logging.warning(
|
|
|
|
|
@@ -780,60 +845,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|
|
|
|
@classmethod
|
|
|
|
|
def create(
|
|
|
|
|
cls,
|
|
|
|
|
repo_id: str,
|
|
|
|
|
fps: int,
|
|
|
|
|
root: Path | None = None,
|
|
|
|
|
robot: Robot | None = None,
|
|
|
|
|
robot_type: str | None = None,
|
|
|
|
|
keys: list[str] | None = None,
|
|
|
|
|
image_keys: list[str] | None = None,
|
|
|
|
|
video_keys: list[str] = None,
|
|
|
|
|
shapes: dict | None = None,
|
|
|
|
|
names: dict | None = None,
|
|
|
|
|
metadata: LeRobotDatasetMetadata,
|
|
|
|
|
tolerance_s: float = 1e-4,
|
|
|
|
|
image_writer_processes: int = 0,
|
|
|
|
|
image_writer_threads_per_camera: int = 0,
|
|
|
|
|
use_videos: bool = True,
|
|
|
|
|
image_writer_threads: int = 0,
|
|
|
|
|
video_backend: str | None = None,
|
|
|
|
|
) -> "LeRobotDataset":
|
|
|
|
|
"""Create a LeRobot Dataset from scratch in order to record data."""
|
|
|
|
|
obj = cls.__new__(cls)
|
|
|
|
|
obj.repo_id = repo_id
|
|
|
|
|
obj.root = root if root is not None else LEROBOT_HOME / repo_id
|
|
|
|
|
obj.meta = metadata
|
|
|
|
|
obj.repo_id = obj.meta.repo_id
|
|
|
|
|
obj.root = obj.meta.root
|
|
|
|
|
obj.local_files_only = obj.meta.local_files_only
|
|
|
|
|
obj.tolerance_s = tolerance_s
|
|
|
|
|
obj.image_writer = None
|
|
|
|
|
|
|
|
|
|
if robot is not None:
|
|
|
|
|
robot_type, keys, image_keys, video_keys, shapes, names = _get_info_from_robot(robot, use_videos)
|
|
|
|
|
if not all(cam.fps == fps for cam in robot.cameras.values()):
|
|
|
|
|
logging.warning(
|
|
|
|
|
f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
|
|
|
|
|
"In this case, frames from lower fps cameras will be repeated to fill in the blanks"
|
|
|
|
|
)
|
|
|
|
|
if len(robot.cameras) > 0 and (image_writer_processes or image_writer_threads_per_camera):
|
|
|
|
|
obj.start_image_writer(
|
|
|
|
|
image_writer_processes, image_writer_threads_per_camera * robot.num_cameras
|
|
|
|
|
)
|
|
|
|
|
elif (
|
|
|
|
|
robot_type is None
|
|
|
|
|
or keys is None
|
|
|
|
|
or image_keys is None
|
|
|
|
|
or video_keys is None
|
|
|
|
|
or shapes is None
|
|
|
|
|
or names is None
|
|
|
|
|
):
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"Dataset info (robot_type, keys, shapes...) must either come from a Robot or explicitly passed upon creation."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if len(video_keys) > 0 and not use_videos:
|
|
|
|
|
raise ValueError()
|
|
|
|
|
|
|
|
|
|
obj.tasks, obj.stats, obj.episode_dicts = {}, {}, []
|
|
|
|
|
obj.info = create_empty_dataset_info(
|
|
|
|
|
CODEBASE_VERSION, fps, robot_type, keys, image_keys, video_keys, shapes, names
|
|
|
|
|
)
|
|
|
|
|
write_json(obj.info, obj.root / INFO_PATH)
|
|
|
|
|
if image_writer_processes or image_writer_threads:
|
|
|
|
|
obj.start_image_writer(image_writer_processes, image_writer_threads)
|
|
|
|
|
|
|
|
|
|
# TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
|
|
|
|
|
obj.episode_buffer = obj._create_episode_buffer()
|
|
|
|
|
@@ -849,7 +877,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|
|
|
|
obj.image_transforms = None
|
|
|
|
|
obj.delta_timestamps = None
|
|
|
|
|
obj.delta_indices = None
|
|
|
|
|
obj.local_files_only = True
|
|
|
|
|
obj.episode_data_index = None
|
|
|
|
|
obj.video_backend = video_backend if video_backend is not None else "pyav"
|
|
|
|
|
return obj
|
|
|
|
|
@@ -889,7 +916,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
|
|
|
|
# Check that some properties are consistent across datasets. Note: We may relax some of these
|
|
|
|
|
# consistency requirements in future iterations of this class.
|
|
|
|
|
for repo_id, dataset in zip(self.repo_ids, self._datasets, strict=True):
|
|
|
|
|
if dataset.info != self._datasets[0].info:
|
|
|
|
|
if dataset.meta.info != self._datasets[0].meta.info:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"Detected a mismatch in dataset info between {self.repo_ids[0]} and {repo_id}. This is "
|
|
|
|
|
"not yet supported."
|
|
|
|
|
@@ -938,7 +965,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
|
|
|
|
|
|
|
|
|
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
|
|
|
|
|
"""
|
|
|
|
|
return self._datasets[0].info["fps"]
|
|
|
|
|
return self._datasets[0].meta.info["fps"]
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def video(self) -> bool:
|
|
|
|
|
@@ -948,7 +975,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
|
|
|
|
|
|
|
|
|
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
|
|
|
|
|
"""
|
|
|
|
|
return self._datasets[0].info.get("video", False)
|
|
|
|
|
return self._datasets[0].meta.info.get("video", False)
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def features(self) -> datasets.Features:
|
|
|
|
|
|