Refactor dataset features

This commit is contained in:
Simon Alibert
2024-11-05 13:10:43 +01:00
parent 757ea175d3
commit aed9f4036a
7 changed files with 172 additions and 185 deletions

View File

@@ -30,11 +30,11 @@ from huggingface_hub import snapshot_download, upload_folder
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats
from lerobot.common.datasets.image_writer import ImageWriter
from lerobot.common.datasets.utils import (
DEFAULT_FEATURES,
EPISODES_PATH,
INFO_PATH,
STATS_PATH,
TASKS_PATH,
_get_info_from_robot,
append_jsonlines,
check_delta_timestamps,
check_timestamps_sync,
@@ -43,6 +43,7 @@ from lerobot.common.datasets.utils import (
create_empty_dataset_info,
get_delta_indices,
get_episode_data_index,
get_features_from_robot,
get_hub_safe_version,
hf_transform_to_torch,
load_episodes,
@@ -116,7 +117,7 @@ class LeRobotDatasetMetadata:
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
ep_chunk = self.get_episode_chunk(ep_index)
fpath = self.videos_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
return Path(fpath)
def get_episode_chunk(self, ep_index: int) -> int:
@@ -128,15 +129,20 @@ class LeRobotDatasetMetadata:
return self.info["data_path"]
@property
def videos_path(self) -> str | None:
def video_path(self) -> str | None:
"""Formattable string for the video files."""
return self.info["videos"]["videos_path"] if len(self.video_keys) > 0 else None
return self.info["video_path"]
@property
def fps(self) -> int:
"""Frames per second used during data collection."""
return self.info["fps"]
@property
def features(self) -> dict[str, dict]:
""""""
return self.info["features"]
@property
def keys(self) -> list[str]:
"""Keys to access non-image data (state, actions etc.)."""
@@ -145,22 +151,27 @@ class LeRobotDatasetMetadata:
@property
def image_keys(self) -> list[str]:
"""Keys to access visual modalities stored as images."""
return self.info["image_keys"]
return [key for key, ft in self.features.items() if ft["dtype"] == "image"]
@property
def video_keys(self) -> list[str]:
"""Keys to access visual modalities stored as videos."""
return self.info["video_keys"]
return [key for key, ft in self.features.items() if ft["dtype"] == "video"]
@property
def camera_keys(self) -> list[str]:
"""Keys to access visual modalities (regardless of their storage method)."""
return self.image_keys + self.video_keys
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
@property
def names(self) -> dict[list[str]]:
def names(self) -> dict[str, list[str]]:
"""Names of the various dimensions of vector modalities."""
return self.info["names"]
return {key: ft["names"] for key, ft in self.features.items()}
@property
def shapes(self) -> dict:
"""Shapes for the different features."""
return {key: tuple(ft["shape"]) for key, ft in self.features.items()}
@property
def total_episodes(self) -> int:
@@ -187,11 +198,6 @@ class LeRobotDatasetMetadata:
"""Max number of episodes per chunk."""
return self.info["chunks_size"]
@property
def shapes(self) -> dict:
"""Shapes for the different features."""
return self.info["shapes"]
@property
def task_to_task_index(self) -> dict:
return {task: task_idx for task_idx, task in self.tasks.items()}
@@ -253,45 +259,33 @@ class LeRobotDatasetMetadata:
root: Path | None = None,
robot: Robot | None = None,
robot_type: str | None = None,
keys: list[str] | None = None,
image_keys: list[str] | None = None,
video_keys: list[str] = None,
shapes: dict | None = None,
names: dict | None = None,
features: dict | None = None,
use_videos: bool = True,
) -> "LeRobotDatasetMetadata":
"""Creates metadata for a LeRobotDataset."""
obj = cls.__new__(cls)
obj.repo_id = repo_id
obj.root = root if root is not None else LEROBOT_HOME / repo_id
obj.image_writer = None
if robot is not None:
robot_type, keys, image_keys, video_keys, shapes, names = _get_info_from_robot(robot, use_videos)
features = get_features_from_robot(robot)
robot_type = robot.robot_type
if not all(cam.fps == fps for cam in robot.cameras.values()):
logging.warning(
f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
"In this case, frames from lower fps cameras will be repeated to fill in the blanks"
"In this case, frames from lower fps cameras will be repeated to fill in the blanks."
)
elif (
robot_type is None
or keys is None
or image_keys is None
or video_keys is None
or shapes is None
or names is None
):
elif robot_type is None or features is None:
raise ValueError(
"Dataset info (robot_type, keys, shapes...) must either come from a Robot or explicitly passed upon creation."
"Dataset features must either come from a Robot or explicitly passed upon creation."
)
if len(video_keys) > 0 and not use_videos:
raise ValueError()
else:
features = {**features, **DEFAULT_FEATURES}
obj.tasks, obj.stats, obj.episodes = {}, {}, []
obj.info = create_empty_dataset_info(
CODEBASE_VERSION, fps, robot_type, keys, image_keys, video_keys, shapes, names
)
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
if len(obj.video_keys) > 0 and not use_videos:
raise ValueError()
write_json(obj.info, obj.root / INFO_PATH)
obj.local_files_only = True
return obj
@@ -509,6 +503,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
hf_dataset = load_dataset("parquet", data_files=files, split="train")
hf_dataset.set_transform(hf_transform_to_torch)
# return hf_dataset.with_format("torch") TODO
return hf_dataset
@property
@@ -662,8 +657,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
"task_index": None,
"frame_index": [],
"timestamp": [],
"next.done": [],
**{key: [] for key in self.meta.keys},
**{key: [] for key in self.meta.features},
**{key: [] for key in self.meta.image_keys},
}
@@ -845,7 +839,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
@classmethod
def create(
cls,
metadata: LeRobotDatasetMetadata,
repo_id: str,
fps: int,
root: Path | None = None,
robot: Robot | None = None,
robot_type: str | None = None,
features: dict | None = None,
use_videos: bool = True,
tolerance_s: float = 1e-4,
image_writer_processes: int = 0,
image_writer_threads: int = 0,
@@ -853,7 +853,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
) -> "LeRobotDataset":
"""Create a LeRobot Dataset from scratch in order to record data."""
obj = cls.__new__(cls)
obj.meta = metadata
obj.meta = LeRobotDatasetMetadata.create(
repo_id=repo_id,
fps=fps,
root=root,
robot=robot,
robot_type=robot_type,
features=features,
use_videos=use_videos,
)
obj.repo_id = obj.meta.repo_id
obj.root = obj.meta.root
obj.local_files_only = obj.meta.local_files_only