# Copyright 2024 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import random from functools import partial from pathlib import Path from typing import Protocol from unittest.mock import patch import datasets import numpy as np import PIL.Image import pytest import torch from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata from lerobot.common.datasets.utils import ( DEFAULT_CHUNK_SIZE, DEFAULT_FEATURES, DEFAULT_PARQUET_PATH, DEFAULT_VIDEO_PATH, get_hf_features_from_features, hf_transform_to_torch, ) from tests.fixtures.constants import ( DEFAULT_FPS, DUMMY_CAMERA_FEATURES, DUMMY_MOTOR_FEATURES, DUMMY_REPO_ID, DUMMY_ROBOT_TYPE, DUMMY_VIDEO_INFO, ) class LeRobotDatasetFactory(Protocol): def __call__(self, *args, **kwargs) -> LeRobotDataset: ... 
def get_task_index(task_dicts: dict, task: str) -> int:
    """Look up the ``task_index`` that belongs to the task description ``task``.

    Args:
        task_dicts: Mapping whose values are dicts holding a ``"task_index"`` and a
            ``"task"`` entry.
        task: Task description to search for.

    Returns:
        The index stored alongside ``task``.

    Raises:
        KeyError: If no entry of ``task_dicts`` carries ``task``.
    """
    # First index -> description, then invert it; on duplicates the later entry wins.
    index_to_task = {}
    for entry in task_dicts.values():
        index_to_task[entry["task_index"]] = entry["task"]
    task_lookup = {description: idx for idx, description in index_to_task.items()}
    return task_lookup[task]


@pytest.fixture(name="img_tensor_factory", scope="session")
def fixture_img_tensor_factory():
    """Provide a factory producing random channel-first image tensors."""

    def _create_img_tensor(height=100, width=100, channels=3, dtype=torch.float32) -> torch.Tensor:
        """Return a ``(channels, height, width)`` tensor drawn from ``torch.rand``."""
        size = (channels, height, width)
        return torch.rand(size, dtype=dtype)

    return _create_img_tensor


@pytest.fixture(name="img_array_factory", scope="session")
def fixture_img_array_factory():
    """Provide a factory producing random channel-last image arrays."""

    def _create_img_array(height=100, width=100, channels=3, dtype=np.uint8) -> np.ndarray:
        """Return a random ``(height, width, channels)`` array of the given dtype.

        Unsigned integer dtypes are filled with values in [0, 255], floating dtypes
        with values in [0, 1). Any other dtype raises ``ValueError``.
        """
        size = (height, width, channels)
        if np.issubdtype(dtype, np.unsignedinteger):
            # Int array in [0, 255] range
            return np.random.randint(0, 256, size=size, dtype=dtype)
        if np.issubdtype(dtype, np.floating):
            # Float array in [0, 1] range
            return np.random.rand(*size).astype(dtype)
        raise ValueError(dtype)

    return _create_img_array


@pytest.fixture(name="img_factory", scope="session")
def fixture_img_factory(img_array_factory):
    """Provide a factory producing random PIL images."""

    def _create_img(height=100, width=100) -> PIL.Image.Image:
        """Return a PIL image built from a random array of the requested size."""
        return PIL.Image.fromarray(img_array_factory(height=height, width=width))

    return _create_img


@pytest.fixture(name="features_factory", scope="session")
def fixture_features_factory():
    """Provide a factory assembling a dataset ``features`` dictionary."""

    def _create_features(
        motor_features: dict | None = None,
        camera_features: dict | None = None,
        use_videos: bool = True,
    ) -> dict:
        """Merge motor, camera and default features into a single dict.

        Args:
            motor_features: Motor feature specs; ``DUMMY_MOTOR_FEATURES`` when ``None``.
            camera_features: Camera feature specs; ``DUMMY_CAMERA_FEATURES`` when ``None``.
            use_videos: When true, cameras are declared as ``"video"`` features and
                ``DUMMY_VIDEO_INFO`` is merged in; otherwise they are ``"image"`` features.

        Returns:
            Dict containing the motor features, the camera features and ``DEFAULT_FEATURES``.
        """
        motors = DUMMY_MOTOR_FEATURES if motor_features is None else motor_features
        cameras = DUMMY_CAMERA_FEATURES if camera_features is None else camera_features

        camera_ft = {}
        for key, ft in cameras.items():
            if use_videos:
                camera_ft[key] = {"dtype": "video", **ft, **DUMMY_VIDEO_INFO}
            else:
                camera_ft[key] = {"dtype": "image", **ft}

        return {**motors, **camera_ft, **DEFAULT_FEATURES}

    return _create_features
@pytest.fixture(name="info_factory", scope="session")
def fixture_info_factory(features_factory):
    """Provide a factory that builds a dataset ``info`` dictionary."""

    def _create_info(
        codebase_version: str = CODEBASE_VERSION,
        fps: int = DEFAULT_FPS,
        robot_type: str = DUMMY_ROBOT_TYPE,
        total_episodes: int = 0,
        total_frames: int = 0,
        total_tasks: int = 0,
        total_videos: int = 0,
        total_chunks: int = 0,
        chunks_size: int = DEFAULT_CHUNK_SIZE,
        data_path: str = DEFAULT_PARQUET_PATH,
        video_path: str = DEFAULT_VIDEO_PATH,
        motor_features: dict | None = None,
        camera_features: dict | None = None,
        use_videos: bool = True,
    ) -> dict:
        """Return an ``info`` dict populated from the given values.

        ``motor_features`` and ``camera_features`` fall back to the dummy constants
        when ``None``. The ``"features"`` entry is produced by ``features_factory``,
        ``"splits"`` is always an empty dict, and ``"video_path"`` is ``None`` when
        ``use_videos`` is false.
        """
        if motor_features is None:
            motor_features = DUMMY_MOTOR_FEATURES
        if camera_features is None:
            camera_features = DUMMY_CAMERA_FEATURES
        features = features_factory(motor_features, camera_features, use_videos)
        return {
            "codebase_version": codebase_version,
            "robot_type": robot_type,
            "total_episodes": total_episodes,
            "total_frames": total_frames,
            "total_tasks": total_tasks,
            "total_videos": total_videos,
            "total_chunks": total_chunks,
            "chunks_size": chunks_size,
            "fps": fps,
            "splits": {},
            "data_path": data_path,
            "video_path": video_path if use_videos else None,
            "features": features,
        }

    return _create_info


@pytest.fixture(name="stats_factory", scope="session")
def fixture_stats_factory():
    """Provide a factory that builds constant placeholder statistics per feature."""

    def _create_stats(
        features: dict[str, dict] | None = None,
    ) -> dict:
        """Return ``max``/``mean``/``min``/``std``/``count`` stats for every feature.

        Args:
            features: Mapping from feature name to a spec containing ``"shape"`` and
                ``"dtype"``. ``None`` is treated as "no features".

        Returns:
            Dict keyed by feature name. Image and video features get float32 stats of
            shape ``(3, 1, 1)``; every other feature gets stats of its own shape and
            dtype. Values are nested lists and ``count`` is always ``[10]``.
        """
        if features is None:
            # The declared default used to crash on ``None.items()``.
            return {}

        stats = {}
        for key, ft in features.items():
            shape = ft["shape"]
            dtype = ft["dtype"]
            if dtype in ["image", "video"]:
                stat_shape, stat_dtype = (3, 1, 1), np.float32
            else:
                stat_shape, stat_dtype = shape, dtype
            stats[key] = {
                "max": np.full(stat_shape, 1, dtype=stat_dtype).tolist(),
                "mean": np.full(stat_shape, 0.5, dtype=stat_dtype).tolist(),
                "min": np.full(stat_shape, 0, dtype=stat_dtype).tolist(),
                "std": np.full(stat_shape, 0.25, dtype=stat_dtype).tolist(),
                "count": [10],
            }
        return stats

    return _create_stats
@pytest.fixture(name="episodes_stats_factory", scope="session")
def fixture_episodes_stats_factory(stats_factory):
    """Provide a factory that builds per-episode statistics."""

    def _create_episodes_stats(
        features: dict[str, dict],
        total_episodes: int = 3,
    ) -> dict:
        """Return ``{episode_index: {"episode_index": ..., "stats": ...}}``.

        One entry is created for each index in ``range(total_episodes)``; the stats
        come from ``stats_factory(features)``.
        """
        return {
            episode_index: {
                "episode_index": episode_index,
                "stats": stats_factory(features),
            }
            for episode_index in range(total_episodes)
        }

    return _create_episodes_stats


@pytest.fixture(name="tasks_factory", scope="session")
def fixture_tasks_factory():
    """Provide a factory that builds a dict of dummy tasks."""

    def _create_tasks(total_tasks: int = 3) -> dict:
        """Return ``{task_index: {"task_index": ..., "task": ...}}`` for ``total_tasks`` tasks."""
        return {
            task_index: {"task_index": task_index, "task": f"Perform action {task_index}."}
            for task_index in range(total_tasks)
        }

    return _create_tasks


@pytest.fixture(name="episodes_factory", scope="session")
def fixture_episodes_factory(tasks_factory):
    """Provide a factory that builds random episode descriptions."""

    def _create_episodes(
        total_episodes: int = 3,
        total_frames: int = 400,
        tasks: dict | None = None,
        multi_task: bool = False,
    ):
        """Return ``{ep_idx: {"episode_index", "tasks", "length"}}``.

        Args:
            total_episodes: Number of episodes to create; must be positive.
            total_frames: Total frames split at random across the episodes; must be
                positive and at least ``total_episodes``.
            tasks: Task dict as produced by ``tasks_factory``. When empty or ``None``,
                a random number of tasks is generated.
            multi_task: When true, an episode may receive up to three tasks.

        Raises:
            ValueError: If the episode, frame or task counts are inconsistent.
        """
        if total_episodes <= 0 or total_frames <= 0:
            raise ValueError("num_episodes and total_length must be positive integers.")
        if total_frames < total_episodes:
            raise ValueError("total_length must be greater than or equal to num_episodes.")

        if not tasks:
            min_tasks = 2 if multi_task else 1
            total_tasks = random.randint(min_tasks, total_episodes)
            tasks = tasks_factory(total_tasks)

        if total_episodes < len(tasks) and not multi_task:
            raise ValueError("The number of tasks should be less than the number of episodes.")

        # Generate random lengths that sum up to total_frames
        lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist()

        tasks_list = [task_dict["task"] for task_dict in tasks.values()]
        num_tasks_available = len(tasks_list)

        episodes = {}
        # Tasks are drawn from the not-yet-used pool first; once it is exhausted,
        # sampling falls back to the full task list.
        remaining_tasks = tasks_list.copy()
        for ep_idx in range(total_episodes):
            num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1
            tasks_to_sample = remaining_tasks if remaining_tasks else tasks_list
            episode_tasks = random.sample(tasks_to_sample, min(num_tasks_in_episode, len(tasks_to_sample)))
            if remaining_tasks:
                for task in episode_tasks:
                    remaining_tasks.remove(task)

            episodes[ep_idx] = {
                "episode_index": ep_idx,
                "tasks": episode_tasks,
                "length": lengths[ep_idx],
            }

        return episodes

    return _create_episodes


@pytest.fixture(name="hf_dataset_factory", scope="session")
def fixture_hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
    """Provide a factory that builds a ``datasets.Dataset`` filled with random frames."""

    def _create_hf_dataset(
        features: dict | None = None,
        tasks: dict | None = None,
        episodes: dict | None = None,
        fps: int = DEFAULT_FPS,
    ) -> datasets.Dataset:
        """Return a dataset with one row per frame of every episode.

        Args:
            features: Feature specs; ``features_factory()`` when empty or ``None``.
            tasks: Task dict; ``tasks_factory()`` when empty or ``None``.
            episodes: Episode dict; ``episodes_factory()`` when empty or ``None``.
            fps: Frame rate used to derive the ``timestamp`` column.

        Returns:
            A ``datasets.Dataset`` with ``hf_transform_to_torch`` set as its transform.
        """
        if not tasks:
            tasks = tasks_factory()
        if not episodes:
            episodes = episodes_factory()
        if not features:
            features = features_factory()

        # Collect per-episode chunks and concatenate once. The leading typed empty
        # arrays keep the resulting dtypes identical to incremental concatenation.
        timestamp_parts = [np.array([], dtype=np.float32)]
        frame_index_parts = [np.array([], dtype=np.int64)]
        episode_index_parts = [np.array([], dtype=np.int64)]
        task_index_parts = [np.array([], dtype=np.int64)]
        for ep_dict in episodes.values():
            length = ep_dict["length"]
            timestamp_parts.append(np.arange(length) / fps)
            frame_index_parts.append(np.arange(length, dtype=int))
            episode_index_parts.append(np.full(length, ep_dict["episode_index"], dtype=int))
            # Only the first task of an episode is recorded in the frame-level column.
            ep_task_index = get_task_index(tasks, ep_dict["tasks"][0])
            task_index_parts.append(np.full(length, ep_task_index, dtype=int))

        timestamp_col = np.concatenate(timestamp_parts)
        frame_index_col = np.concatenate(frame_index_parts)
        episode_index_col = np.concatenate(episode_index_parts)
        task_index = np.concatenate(task_index_parts)
        index_col = np.arange(len(episode_index_col))

        robot_cols = {}
        for key, ft in features.items():
            if ft["dtype"] == "image":
                # The key is "shape" (as used everywhere else), not "shapes".
                # NOTE(review): the index order below assumes shape is (width, height, ...);
                # confirm against the camera feature definitions.
                robot_cols[key] = [
                    img_array_factory(height=ft["shape"][1], width=ft["shape"][0])
                    for _ in range(len(index_col))
                ]
            elif ft["shape"][0] > 1 and ft["dtype"] != "video":
                robot_cols[key] = np.random.random((len(index_col), ft["shape"][0])).astype(ft["dtype"])

        hf_features = get_hf_features_from_features(features)
        dataset = datasets.Dataset.from_dict(
            {
                **robot_cols,
                "timestamp": timestamp_col,
                "frame_index": frame_index_col,
                "episode_index": episode_index_col,
                "index": index_col,
                "task_index": task_index,
            },
            features=hf_features,
        )
        dataset.set_transform(hf_transform_to_torch)
        return dataset

    return _create_hf_dataset


@pytest.fixture(name="lerobot_dataset_metadata_factory", scope="session")
def fixture_lerobot_dataset_metadata_factory(
    info_factory,
    stats_factory,
    episodes_stats_factory,
    tasks_factory,
    episodes_factory,
    mock_snapshot_download_factory,
):
    """Provide a factory that builds ``LeRobotDatasetMetadata`` with hub access patched out."""

    def _create_lerobot_dataset_metadata(
        root: Path,
        repo_id: str = DUMMY_REPO_ID,
        info: dict | None = None,
        stats: dict | None = None,
        episodes_stats: dict | None = None,
        tasks: dict | None = None,
        episodes: dict | None = None,
    ) -> LeRobotDatasetMetadata:
        """Return metadata for ``repo_id`` rooted at ``root``.

        Any argument left empty or ``None`` is generated from the matching factory,
        sized according to ``info``. ``get_safe_version`` and ``snapshot_download``
        in ``lerobot.common.datasets.lerobot_dataset`` are patched while the metadata
        object is constructed.
        """
        if not info:
            info = info_factory()
        if not stats:
            stats = stats_factory(features=info["features"])
        if not episodes_stats:
            episodes_stats = episodes_stats_factory(
                features=info["features"], total_episodes=info["total_episodes"]
            )
        if not tasks:
            tasks = tasks_factory(total_tasks=info["total_tasks"])
        if not episodes:
            episodes = episodes_factory(
                total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
            )

        mock_snapshot_download = mock_snapshot_download_factory(
            info=info,
            stats=stats,
            episodes_stats=episodes_stats,
            tasks=tasks,
            episodes=episodes,
        )
        with (
            patch("lerobot.common.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch,
            patch(
                "lerobot.common.datasets.lerobot_dataset.snapshot_download"
            ) as mock_snapshot_download_patch,
        ):
            # Version resolution becomes the identity; downloads go through the mock.
            mock_get_safe_version_patch.side_effect = lambda repo_id, version: version
            mock_snapshot_download_patch.side_effect = mock_snapshot_download

            return LeRobotDatasetMetadata(repo_id=repo_id, root=root)

    return _create_lerobot_dataset_metadata


@pytest.fixture(name="lerobot_dataset_factory", scope="session")
def fixture_lerobot_dataset_factory(
    info_factory,
    stats_factory,
    episodes_stats_factory,
    tasks_factory,
    episodes_factory,
    hf_dataset_factory,
    mock_snapshot_download_factory,
    lerobot_dataset_metadata_factory,
) -> LeRobotDatasetFactory:
    """Provide a factory that builds a ``LeRobotDataset`` with hub access patched out."""

    def _create_lerobot_dataset(
        root: Path,
        repo_id: str = DUMMY_REPO_ID,
        total_episodes: int = 3,
        total_frames: int = 150,
        total_tasks: int = 1,
        multi_task: bool = False,
        info: dict | None = None,
        stats: dict | None = None,
        episodes_stats: dict | None = None,
        tasks: dict | None = None,
        episode_dicts: dict | None = None,
        hf_dataset: datasets.Dataset | None = None,
        **kwargs,
    ) -> LeRobotDataset:
        """Return a ``LeRobotDataset`` for ``repo_id`` rooted at ``root``.

        ``total_episodes``, ``total_frames`` and ``total_tasks`` are only used to
        build ``info`` when it is not supplied; every other generated piece is sized
        from ``info``. Extra ``kwargs`` are forwarded to ``LeRobotDataset``.
        """
        if not info:
            info = info_factory(
                total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
            )
        if not stats:
            stats = stats_factory(features=info["features"])
        if not episodes_stats:
            # Size from ``info`` (not the ``total_episodes`` argument) so that a
            # caller-supplied ``info`` stays consistent with the generated stats.
            episodes_stats = episodes_stats_factory(
                features=info["features"], total_episodes=info["total_episodes"]
            )
        if not tasks:
            tasks = tasks_factory(total_tasks=info["total_tasks"])
        if not episode_dicts:
            episode_dicts = episodes_factory(
                total_episodes=info["total_episodes"],
                total_frames=info["total_frames"],
                tasks=tasks,
                multi_task=multi_task,
            )
        if not hf_dataset:
            # NOTE(review): ``features`` is not forwarded, so the hf dataset is built
            # from the default features rather than ``info["features"]`` — confirm intended.
            hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episode_dicts, fps=info["fps"])

        mock_snapshot_download = mock_snapshot_download_factory(
            info=info,
            stats=stats,
            episodes_stats=episodes_stats,
            tasks=tasks,
            episodes=episode_dicts,
            hf_dataset=hf_dataset,
        )
        mock_metadata = lerobot_dataset_metadata_factory(
            root=root,
            repo_id=repo_id,
            info=info,
            stats=stats,
            episodes_stats=episodes_stats,
            tasks=tasks,
            episodes=episode_dicts,
        )
        with (
            patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,
            patch("lerobot.common.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch,
            patch(
                "lerobot.common.datasets.lerobot_dataset.snapshot_download"
            ) as mock_snapshot_download_patch,
        ):
            # The dataset receives the prebuilt metadata instead of constructing its own.
            mock_metadata_patch.return_value = mock_metadata
            mock_get_safe_version_patch.side_effect = lambda repo_id, version: version
            mock_snapshot_download_patch.side_effect = mock_snapshot_download

            return LeRobotDataset(repo_id=repo_id, root=root, **kwargs)

    return _create_lerobot_dataset


@pytest.fixture(name="empty_lerobot_dataset_factory", scope="session")
def fixture_empty_lerobot_dataset_factory() -> LeRobotDatasetFactory:
    """Provide ``LeRobotDataset.create`` pre-bound with the dummy repo id and default fps."""
    return partial(LeRobotDataset.create, repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS)