Refactor dataset features

This commit is contained in:
Simon Alibert
2024-11-05 13:10:43 +01:00
parent 757ea175d3
commit aed9f4036a
7 changed files with 172 additions and 185 deletions

View File

@@ -48,6 +48,14 @@ This dataset was created using [LeRobot](https://github.com/huggingface/lerobot)
"""
DEFAULT_FEATURES = {
"timestamp": {"dtype": "float32", "shape": (1,), "names": None},
"frame_index": {"dtype": "int64", "shape": (1,), "names": None},
"episode_index": {"dtype": "int64", "shape": (1,), "names": None},
"index": {"dtype": "int64", "shape": (1,), "names": None},
"task_index": {"dtype": "int64", "shape": (1,), "names": None},
}
def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
@@ -214,39 +222,25 @@ def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) ->
return version
def _get_info_from_robot(robot: Robot, use_videos: bool) -> tuple[list | dict]:
shapes = {key: len(names) for key, names in robot.names.items()}
camera_shapes = {}
for key, cam in robot.cameras.items():
video_key = f"observation.images.{key}"
camera_shapes[video_key] = {
"width": cam.width,
"height": cam.height,
"channels": cam.channels,
def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
camera_ft = {}
if robot.cameras:
camera_ft = {
key: {"dtype": "video" if use_videos else "image", **ft}
for key, ft in robot.camera_features.items()
}
keys = list(robot.names)
image_keys = [] if use_videos else list(camera_shapes)
video_keys = list(camera_shapes) if use_videos else []
shapes = {**shapes, **camera_shapes}
names = robot.names
robot_type = robot.robot_type
return robot_type, keys, image_keys, video_keys, shapes, names
return {**robot.motor_features, **camera_ft, **DEFAULT_FEATURES}
def create_empty_dataset_info(
codebase_version: str,
fps: int,
robot_type: str,
keys: list[str],
image_keys: list[str],
video_keys: list[str],
shapes: dict,
names: dict,
features: dict,
use_videos: bool,
) -> dict:
return {
"codebase_version": codebase_version,
"data_path": DEFAULT_PARQUET_PATH,
"robot_type": robot_type,
"total_episodes": 0,
"total_frames": 0,
@@ -256,12 +250,9 @@ def create_empty_dataset_info(
"chunks_size": DEFAULT_CHUNK_SIZE,
"fps": fps,
"splits": {},
"keys": keys,
"video_keys": video_keys,
"image_keys": image_keys,
"shapes": shapes,
"names": names,
"videos": {"videos_path": DEFAULT_VIDEO_PATH} if len(video_keys) > 0 else None,
"data_path": DEFAULT_PARQUET_PATH,
"video_path": DEFAULT_VIDEO_PATH if use_videos else None,
"features": features,
}
@@ -400,6 +391,12 @@ def create_lerobot_dataset_card(
tags: list | None = None, text: str | None = None, info: dict | None = None
) -> DatasetCard:
card = DatasetCard(DATASET_CARD_TEMPLATE)
card.data.configs = [
{
"config_name": "default",
"data_files": "data/*/*.parquet",
}
]
card.data.task_categories = ["robotics"]
card.data.tags = ["LeRobot"]
if tags is not None: