diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py
index c79e49d9..9d922c8a 100644
--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -48,7 +48,6 @@ from lerobot.common.datasets.utils import (
     embed_images,
     get_delta_indices,
     get_episode_data_index,
-    get_features_from_robot,
     get_hf_features_from_features,
     get_safe_version,
     hf_transform_to_torch,
@@ -72,7 +71,6 @@ from lerobot.common.datasets.video_utils import (
     get_safe_default_codec,
     get_video_info,
 )
-from lerobot.common.robots import Robot

 CODEBASE_VERSION = "v2.1"
@@ -304,10 +302,9 @@ class LeRobotDatasetMetadata:
         cls,
         repo_id: str,
         fps: int,
-        root: str | Path | None = None,
-        robot: Robot | None = None,
+        features: dict,
         robot_type: str | None = None,
-        features: dict | None = None,
+        root: str | Path | None = None,
         use_videos: bool = True,
     ) -> "LeRobotDatasetMetadata":
         """Creates metadata for a LeRobotDataset."""
@@ -317,33 +314,27 @@ class LeRobotDatasetMetadata:
         obj.root.mkdir(parents=True, exist_ok=False)

-        if robot is not None:
-            features = get_features_from_robot(robot, use_videos)
-            robot_type = robot.robot_type
-            if not all(cam.fps == fps for cam in robot.cameras.values()):
-                logging.warning(
-                    f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
-                    "In this case, frames from lower fps cameras will be repeated to fill in the blanks."
-                )
-        elif features is None:
-            raise ValueError(
-                "Dataset features must either come from a Robot or explicitly passed upon creation."
-            )
-        else:
-            # TODO(aliberts, rcadene): implement sanity check for features
-            features = {**features, **DEFAULT_FEATURES}
+        # if robot is not None:
+        #     features = get_features_from_robot(robot, use_videos)
+        #     robot_type = robot.robot_type
+        #     if not all(cam.fps == fps for cam in robot.cameras.values()):
+        #         logging.warning(
+        #             f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
+        #             "In this case, frames from lower fps cameras will be repeated to fill in the blanks."
+        #         )

-            # check if none of the features contains a "/" in their names,
-            # as this would break the dict flattening in the stats computation, which uses '/' as separator
-            for key in features:
-                if "/" in key:
-                    raise ValueError(f"Feature names should not contain '/'. Found '/' in feature '{key}'.")
+        # TODO(aliberts, rcadene): implement sanity check for features
+        features = {**features, **DEFAULT_FEATURES}

-            features = {**features, **DEFAULT_FEATURES}
+        # check if none of the features contains a "/" in their names,
+        # as this would break the dict flattening in the stats computation, which uses '/' as separator
+        for key in features:
+            if "/" in key:
+                raise ValueError(f"Feature names should not contain '/'. Found '/' in feature '{key}'.")

         obj.tasks, obj.task_to_task_index = {}, {}
         obj.episodes_stats, obj.stats, obj.episodes = {}, {}, {}
-        obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
+        obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, features, use_videos, robot_type)
         if len(obj.video_keys) > 0 and not use_videos:
             raise ValueError()
         write_json(obj.info, obj.root / INFO_PATH)
@@ -986,10 +977,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
         cls,
         repo_id: str,
         fps: int,
+        features: dict,
         root: str | Path | None = None,
-        robot: Robot | None = None,
         robot_type: str | None = None,
-        features: dict | None = None,
         use_videos: bool = True,
         tolerance_s: float = 1e-4,
         image_writer_processes: int = 0,
@@ -1001,10 +991,9 @@
         obj.meta = LeRobotDatasetMetadata.create(
             repo_id=repo_id,
             fps=fps,
-            root=root,
-            robot=robot,
             robot_type=robot_type,
             features=features,
+            root=root,
             use_videos=use_videos,
         )
         obj.repo_id = obj.meta.repo_id
diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index 31a3cca7..581b3c1d 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -477,9 +477,9 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
 def create_empty_dataset_info(
     codebase_version: str,
     fps: int,
-    robot_type: str,
     features: dict,
     use_videos: bool,
+    robot_type: str | None = None,
 ) -> dict:
     return {
         "codebase_version": codebase_version,
diff --git a/lerobot/common/robots/utils.py b/lerobot/common/robots/utils.py
index 3c543b64..1fa73c57 100644
--- a/lerobot/common/robots/utils.py
+++ b/lerobot/common/robots/utils.py
@@ -16,11 +16,11 @@ def make_robot_config(robot_type: str, **kwargs) -> RobotConfig:
         return KochFollowerConfig(**kwargs)
     # elif robot_type == "koch_bimanual":
     #     return KochBimanualRobotConfig(**kwargs)
-    elif robot_type == "moss":
+    elif robot_type == "moss_follower":
         from .moss_follower.configuration_moss import MossRobotConfig

         return MossRobotConfig(**kwargs)
-    elif robot_type == "so100_leader":
+    elif robot_type == "so100_follower":
         from .so100_follower.config_so100_follower import SO100FollowerConfig

         return SO100FollowerConfig(**kwargs)
diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py
index 2b0fc5b1..55a417c3 100644
--- a/tests/datasets/test_datasets.py
+++ b/tests/datasets/test_datasets.py
@@ -41,7 +41,6 @@ from lerobot.common.datasets.utils import (
 )
 from lerobot.common.envs.factory import make_env_config
 from lerobot.common.policies.factory import make_policy_config
-from lerobot.common.robots.utils import make_robot
 from lerobot.configs.default import DatasetConfig
 from lerobot.configs.train import TrainPipelineConfig
 from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
@@ -70,9 +69,9 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
     objects have the same sets of attributes defined.
""" # Instantiate both ways - robot = make_robot("koch", mock=True) + features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} root_create = tmp_path / "create" - dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create) + dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, features=features, root=root_create) root_init = tmp_path / "init" dataset_init = lerobot_dataset_factory(root=root_init) @@ -100,22 +99,13 @@ def test_dataset_initialization(tmp_path, lerobot_dataset_factory): assert dataset.num_frames == len(dataset) -def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory): - features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} - dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - with pytest.raises( - ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n" - ): - dataset.add_frame({"state": torch.randn(1)}) - - def test_add_frame_missing_feature(tmp_path, empty_lerobot_dataset_factory): features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) with pytest.raises( ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'state'}\n" ): - dataset.add_frame({"task": "Dummy task"}) + dataset.add_frame({"wrong_feature": torch.randn(1)}, task="Dummy task") def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory): @@ -124,7 +114,7 @@ def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory): with pytest.raises( ValueError, match="Feature mismatch in `frame` dictionary:\nExtra features: {'extra'}\n" ): - dataset.add_frame({"state": torch.randn(1), "task": "Dummy task", "extra": "dummy_extra"}) + dataset.add_frame({"state": torch.randn(1), "extra": "dummy_extra"}, task="Dummy task") def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory): @@ -133,7 +123,7 @@ def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory): with pytest.raises( ValueError, match="The feature 'state' of dtype 'float16' is not of the expected dtype 'float32'.\n" ): - dataset.add_frame({"state": torch.randn(1, dtype=torch.float16), "task": "Dummy task"}) + dataset.add_frame({"state": torch.randn(1, dtype=torch.float16)}, task="Dummy task") def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory): @@ -143,7 +133,7 @@ def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory): ValueError, match=re.escape("The feature 'state' of shape '(1,)' does not have the expected shape '(2,)'.\n"), ): - dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"}) + dataset.add_frame({"state": torch.randn(1)}, task="Dummy task") def test_add_frame_wrong_shape_python_float(tmp_path, empty_lerobot_dataset_factory): @@ -155,7 +145,7 @@ def test_add_frame_wrong_shape_python_float(tmp_path, empty_lerobot_dataset_fact "The feature 'state' is not a 'np.ndarray'. 
Expected type is 'float32', but type '' provided instead.\n" ), ): - dataset.add_frame({"state": 1.0, "task": "Dummy task"}) + dataset.add_frame({"state": 1.0}, task="Dummy task") def test_add_frame_wrong_shape_torch_ndim_0(tmp_path, empty_lerobot_dataset_factory): @@ -165,7 +155,7 @@ def test_add_frame_wrong_shape_torch_ndim_0(tmp_path, empty_lerobot_dataset_fact ValueError, match=re.escape("The feature 'state' of shape '()' does not have the expected shape '(1,)'.\n"), ): - dataset.add_frame({"state": torch.tensor(1.0), "task": "Dummy task"}) + dataset.add_frame({"state": torch.tensor(1.0)}, task="Dummy task") def test_add_frame_wrong_shape_numpy_ndim_0(tmp_path, empty_lerobot_dataset_factory): @@ -177,13 +167,13 @@ def test_add_frame_wrong_shape_numpy_ndim_0(tmp_path, empty_lerobot_dataset_fact "The feature 'state' is not a 'np.ndarray'. Expected type is 'float32', but type '' provided instead.\n" ), ): - dataset.add_frame({"state": np.float32(1.0), "task": "Dummy task"}) + dataset.add_frame({"state": np.float32(1.0)}, task="Dummy task") def test_add_frame(tmp_path, empty_lerobot_dataset_factory): features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"}) + dataset.add_frame({"state": torch.randn(1)}, task="Dummy task") dataset.save_episode() assert len(dataset) == 1 @@ -195,7 +185,7 @@ def test_add_frame(tmp_path, empty_lerobot_dataset_factory): def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory): features = {"state": {"dtype": "float32", "shape": (2,), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - dataset.add_frame({"state": torch.randn(2), "task": "Dummy task"}) + dataset.add_frame({"state": torch.randn(2)}, task="Dummy task") dataset.save_episode() assert dataset[0]["state"].shape == torch.Size([2]) @@ -204,7 +194,7 @@ def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory): def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory): features = {"state": {"dtype": "float32", "shape": (2, 4), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - dataset.add_frame({"state": torch.randn(2, 4), "task": "Dummy task"}) + dataset.add_frame({"state": torch.randn(2, 4)}, task="Dummy task") dataset.save_episode() assert dataset[0]["state"].shape == torch.Size([2, 4]) @@ -213,7 +203,7 @@ def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory): def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory): features = {"state": {"dtype": "float32", "shape": (2, 4, 3), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - dataset.add_frame({"state": torch.randn(2, 4, 3), "task": "Dummy task"}) + dataset.add_frame({"state": torch.randn(2, 4, 3)}, task="Dummy task") dataset.save_episode() assert dataset[0]["state"].shape == torch.Size([2, 4, 3]) @@ -222,7 +212,7 @@ def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory): def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory): features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) - dataset.add_frame({"state": torch.randn(2, 4, 3, 5), "task": "Dummy task"}) + dataset.add_frame({"state": torch.randn(2, 4, 3, 5)}, task="Dummy task") 
     dataset.save_episode()

     assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5])
@@ -231,7 +221,7 @@ def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory):
     features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5, 1), "names": None}}
     dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
-    dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1), "task": "Dummy task"})
+    dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1)}, task="Dummy task")
     dataset.save_episode()

     assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5, 1])
@@ -240,7 +230,7 @@ def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory):
     features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
     dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
-    dataset.add_frame({"state": np.array([1], dtype=np.float32), "task": "Dummy task"})
+    dataset.add_frame({"state": np.array([1], dtype=np.float32)}, task="Dummy task")
     dataset.save_episode()

     assert dataset[0]["state"].ndim == 0
@@ -249,7 +239,7 @@ def test_add_frame_string(tmp_path, empty_lerobot_dataset_factory):
     features = {"caption": {"dtype": "string", "shape": (1,), "names": None}}
     dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
-    dataset.add_frame({"caption": "Dummy caption", "task": "Dummy task"})
+    dataset.add_frame({"caption": "Dummy caption"}, task="Dummy task")
     dataset.save_episode()

     assert dataset[0]["caption"] == "Dummy caption"
@@ -264,7 +254,7 @@ def test_add_frame_image_wrong_shape(image_dataset):
         ),
     ):
         c, h, w = DUMMY_CHW
-        dataset.add_frame({"image": torch.randn(c, w, h), "task": "Dummy task"})
+        dataset.add_frame({"image": torch.randn(c, w, h)}, task="Dummy task")
@@ -277,14 +267,14 @@ def test_add_frame_image_wrong_range(image_dataset):
     Hence the image won't be saved on disk and save_episode will raise `FileNotFoundError`.
""" dataset = image_dataset - dataset.add_frame({"image": np.random.rand(*DUMMY_CHW) * 255, "task": "Dummy task"}) + dataset.add_frame({"image": np.random.rand(*DUMMY_CHW) * 255}, task="Dummy task") with pytest.raises(FileNotFoundError): dataset.save_episode() def test_add_frame_image(image_dataset): dataset = image_dataset - dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"}) + dataset.add_frame({"image": np.random.rand(*DUMMY_CHW)}, task="Dummy task") dataset.save_episode() assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) @@ -292,7 +282,7 @@ def test_add_frame_image(image_dataset): def test_add_frame_image_h_w_c(image_dataset): dataset = image_dataset - dataset.add_frame({"image": np.random.rand(*DUMMY_HWC), "task": "Dummy task"}) + dataset.add_frame({"image": np.random.rand(*DUMMY_HWC)}, task="Dummy task") dataset.save_episode() assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) @@ -301,7 +291,7 @@ def test_add_frame_image_h_w_c(image_dataset): def test_add_frame_image_uint8(image_dataset): dataset = image_dataset image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8) - dataset.add_frame({"image": image, "task": "Dummy task"}) + dataset.add_frame({"image": image}, task="Dummy task") dataset.save_episode() assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) @@ -310,7 +300,7 @@ def test_add_frame_image_uint8(image_dataset): def test_add_frame_image_pil(image_dataset): dataset = image_dataset image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8) - dataset.add_frame({"image": Image.fromarray(image), "task": "Dummy task"}) + dataset.add_frame({"image": Image.fromarray(image)}, task="Dummy task") dataset.save_episode() assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)