Fix test_datasets
This commit is contained in:
@@ -48,7 +48,6 @@ from lerobot.common.datasets.utils import (
|
||||
embed_images,
|
||||
get_delta_indices,
|
||||
get_episode_data_index,
|
||||
get_features_from_robot,
|
||||
get_hf_features_from_features,
|
||||
get_safe_version,
|
||||
hf_transform_to_torch,
|
||||
@@ -72,7 +71,6 @@ from lerobot.common.datasets.video_utils import (
|
||||
get_safe_default_codec,
|
||||
get_video_info,
|
||||
)
|
||||
from lerobot.common.robots import Robot
|
||||
|
||||
CODEBASE_VERSION = "v2.1"
|
||||
|
||||
@@ -304,10 +302,9 @@ class LeRobotDatasetMetadata:
|
||||
cls,
|
||||
repo_id: str,
|
||||
fps: int,
|
||||
root: str | Path | None = None,
|
||||
robot: Robot | None = None,
|
||||
features: dict,
|
||||
robot_type: str | None = None,
|
||||
features: dict | None = None,
|
||||
root: str | Path | None = None,
|
||||
use_videos: bool = True,
|
||||
) -> "LeRobotDatasetMetadata":
|
||||
"""Creates metadata for a LeRobotDataset."""
|
||||
@@ -317,33 +314,27 @@ class LeRobotDatasetMetadata:
|
||||
|
||||
obj.root.mkdir(parents=True, exist_ok=False)
|
||||
|
||||
if robot is not None:
|
||||
features = get_features_from_robot(robot, use_videos)
|
||||
robot_type = robot.robot_type
|
||||
if not all(cam.fps == fps for cam in robot.cameras.values()):
|
||||
logging.warning(
|
||||
f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
|
||||
"In this case, frames from lower fps cameras will be repeated to fill in the blanks."
|
||||
)
|
||||
elif features is None:
|
||||
raise ValueError(
|
||||
"Dataset features must either come from a Robot or explicitly passed upon creation."
|
||||
)
|
||||
else:
|
||||
# TODO(aliberts, rcadene): implement sanity check for features
|
||||
features = {**features, **DEFAULT_FEATURES}
|
||||
# if robot is not None:
|
||||
# features = get_features_from_robot(robot, use_videos)
|
||||
# robot_type = robot.robot_type
|
||||
# if not all(cam.fps == fps for cam in robot.cameras.values()):
|
||||
# logging.warning(
|
||||
# f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
|
||||
# "In this case, frames from lower fps cameras will be repeated to fill in the blanks."
|
||||
# )
|
||||
|
||||
# check if none of the features contains a "/" in their names,
|
||||
# as this would break the dict flattening in the stats computation, which uses '/' as separator
|
||||
for key in features:
|
||||
if "/" in key:
|
||||
raise ValueError(f"Feature names should not contain '/'. Found '/' in feature '{key}'.")
|
||||
# TODO(aliberts, rcadene): implement sanity check for features
|
||||
features = {**features, **DEFAULT_FEATURES}
|
||||
|
||||
features = {**features, **DEFAULT_FEATURES}
|
||||
# check if none of the features contains a "/" in their names,
|
||||
# as this would break the dict flattening in the stats computation, which uses '/' as separator
|
||||
for key in features:
|
||||
if "/" in key:
|
||||
raise ValueError(f"Feature names should not contain '/'. Found '/' in feature '{key}'.")
|
||||
|
||||
obj.tasks, obj.task_to_task_index = {}, {}
|
||||
obj.episodes_stats, obj.stats, obj.episodes = {}, {}, {}
|
||||
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
|
||||
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, features, use_videos, robot_type)
|
||||
if len(obj.video_keys) > 0 and not use_videos:
|
||||
raise ValueError()
|
||||
write_json(obj.info, obj.root / INFO_PATH)
|
||||
@@ -986,10 +977,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
cls,
|
||||
repo_id: str,
|
||||
fps: int,
|
||||
features: dict,
|
||||
root: str | Path | None = None,
|
||||
robot: Robot | None = None,
|
||||
robot_type: str | None = None,
|
||||
features: dict | None = None,
|
||||
use_videos: bool = True,
|
||||
tolerance_s: float = 1e-4,
|
||||
image_writer_processes: int = 0,
|
||||
@@ -1001,10 +991,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=repo_id,
|
||||
fps=fps,
|
||||
root=root,
|
||||
robot=robot,
|
||||
robot_type=robot_type,
|
||||
features=features,
|
||||
root=root,
|
||||
use_videos=use_videos,
|
||||
)
|
||||
obj.repo_id = obj.meta.repo_id
|
||||
|
||||
@@ -477,9 +477,9 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
|
||||
def create_empty_dataset_info(
|
||||
codebase_version: str,
|
||||
fps: int,
|
||||
robot_type: str,
|
||||
features: dict,
|
||||
use_videos: bool,
|
||||
robot_type: str | None = None,
|
||||
) -> dict:
|
||||
return {
|
||||
"codebase_version": codebase_version,
|
||||
|
||||
@@ -16,11 +16,11 @@ def make_robot_config(robot_type: str, **kwargs) -> RobotConfig:
|
||||
return KochFollowerConfig(**kwargs)
|
||||
# elif robot_type == "koch_bimanual":
|
||||
# return KochBimanualRobotConfig(**kwargs)
|
||||
elif robot_type == "moss":
|
||||
elif robot_type == "moss_follower":
|
||||
from .moss_follower.configuration_moss import MossRobotConfig
|
||||
|
||||
return MossRobotConfig(**kwargs)
|
||||
elif robot_type == "so100_leader":
|
||||
elif robot_type == "so100_follower":
|
||||
from .so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
|
||||
return SO100FollowerConfig(**kwargs)
|
||||
|
||||
Reference in New Issue
Block a user