Fix test_datasets

This commit is contained in:
Simon Alibert
2025-05-08 14:57:12 +02:00
parent 320f7e8450
commit 230c7fdfab
4 changed files with 46 additions and 67 deletions

View File

@@ -48,7 +48,6 @@ from lerobot.common.datasets.utils import (
embed_images,
get_delta_indices,
get_episode_data_index,
get_features_from_robot,
get_hf_features_from_features,
get_safe_version,
hf_transform_to_torch,
@@ -72,7 +71,6 @@ from lerobot.common.datasets.video_utils import (
get_safe_default_codec,
get_video_info,
)
from lerobot.common.robots import Robot
CODEBASE_VERSION = "v2.1"
@@ -304,10 +302,9 @@ class LeRobotDatasetMetadata:
cls,
repo_id: str,
fps: int,
root: str | Path | None = None,
robot: Robot | None = None,
features: dict,
robot_type: str | None = None,
features: dict | None = None,
root: str | Path | None = None,
use_videos: bool = True,
) -> "LeRobotDatasetMetadata":
"""Creates metadata for a LeRobotDataset."""
@@ -317,33 +314,27 @@ class LeRobotDatasetMetadata:
obj.root.mkdir(parents=True, exist_ok=False)
if robot is not None:
features = get_features_from_robot(robot, use_videos)
robot_type = robot.robot_type
if not all(cam.fps == fps for cam in robot.cameras.values()):
logging.warning(
f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
"In this case, frames from lower fps cameras will be repeated to fill in the blanks."
)
elif features is None:
raise ValueError(
"Dataset features must either come from a Robot or explicitly passed upon creation."
)
else:
# TODO(aliberts, rcadene): implement sanity check for features
features = {**features, **DEFAULT_FEATURES}
# if robot is not None:
# features = get_features_from_robot(robot, use_videos)
# robot_type = robot.robot_type
# if not all(cam.fps == fps for cam in robot.cameras.values()):
# logging.warning(
# f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
# "In this case, frames from lower fps cameras will be repeated to fill in the blanks."
# )
# check if none of the features contains a "/" in their names,
# as this would break the dict flattening in the stats computation, which uses '/' as separator
for key in features:
if "/" in key:
raise ValueError(f"Feature names should not contain '/'. Found '/' in feature '{key}'.")
# TODO(aliberts, rcadene): implement sanity check for features
features = {**features, **DEFAULT_FEATURES}
features = {**features, **DEFAULT_FEATURES}
# check if none of the features contains a "/" in their names,
# as this would break the dict flattening in the stats computation, which uses '/' as separator
for key in features:
if "/" in key:
raise ValueError(f"Feature names should not contain '/'. Found '/' in feature '{key}'.")
obj.tasks, obj.task_to_task_index = {}, {}
obj.episodes_stats, obj.stats, obj.episodes = {}, {}, {}
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, features, use_videos, robot_type)
if len(obj.video_keys) > 0 and not use_videos:
raise ValueError()
write_json(obj.info, obj.root / INFO_PATH)
@@ -986,10 +977,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
cls,
repo_id: str,
fps: int,
features: dict,
root: str | Path | None = None,
robot: Robot | None = None,
robot_type: str | None = None,
features: dict | None = None,
use_videos: bool = True,
tolerance_s: float = 1e-4,
image_writer_processes: int = 0,
@@ -1001,10 +991,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.meta = LeRobotDatasetMetadata.create(
repo_id=repo_id,
fps=fps,
root=root,
robot=robot,
robot_type=robot_type,
features=features,
root=root,
use_videos=use_videos,
)
obj.repo_id = obj.meta.repo_id

View File

@@ -477,9 +477,9 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
def create_empty_dataset_info(
codebase_version: str,
fps: int,
robot_type: str,
features: dict,
use_videos: bool,
robot_type: str | None = None,
) -> dict:
return {
"codebase_version": codebase_version,

View File

@@ -16,11 +16,11 @@ def make_robot_config(robot_type: str, **kwargs) -> RobotConfig:
return KochFollowerConfig(**kwargs)
# elif robot_type == "koch_bimanual":
# return KochBimanualRobotConfig(**kwargs)
elif robot_type == "moss":
elif robot_type == "moss_follower":
from .moss_follower.configuration_moss import MossRobotConfig
return MossRobotConfig(**kwargs)
elif robot_type == "so100_leader":
elif robot_type == "so100_follower":
from .so100_follower.config_so100_follower import SO100FollowerConfig
return SO100FollowerConfig(**kwargs)