From e07cb52baa61c61f99d538f751967b09b3ba1ee8 Mon Sep 17 00:00:00 2001
From: Remi Cadene
Date: Mon, 12 May 2025 15:37:02 +0200
Subject: [PATCH] In tests: Add use_videos=False by default, Create mp4 file if
 True, then fix test_datasets and test_aggregate (all passing)

---
 lerobot/common/datasets/aggregate.py       |  4 +-
 lerobot/common/datasets/lerobot_dataset.py |  6 +--
 lerobot/common/datasets/utils.py           |  8 ++++
 tests/datasets/test_datasets.py            |  2 +-
 tests/fixtures/constants.py                |  4 +-
 tests/fixtures/dataset_factories.py        | 44 +++++++++++++++++++---
 tests/fixtures/hub.py                      | 42 +++++++++++++--------
 7 files changed, 81 insertions(+), 29 deletions(-)

diff --git a/lerobot/common/datasets/aggregate.py b/lerobot/common/datasets/aggregate.py
index 2cf58ff5..193a1e67 100644
--- a/lerobot/common/datasets/aggregate.py
+++ b/lerobot/common/datasets/aggregate.py
@@ -17,6 +17,7 @@ from lerobot.common.datasets.utils import (
     concat_video_files,
     get_parquet_file_size_in_mb,
     get_video_size_in_mb,
+    safe_write_dataframe_to_parquet,
     update_chunk_file_indices,
     write_info,
     write_stats,
@@ -97,6 +98,7 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
     fps, robot_type, features = validate_all_metadata(all_metadata)
 
     video_keys = [key for key in features if features[key]["dtype"] == "video"]
+    image_keys = [key for key in features if features[key]["dtype"] == "image"]
 
     # Create resulting dataset folder
     aggr_meta = LeRobotDatasetMetadata.create(
@@ -259,7 +261,7 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
             # Update the existing parquet file with new rows
             aggr_df = pd.read_parquet(aggr_path)
             df = pd.concat([aggr_df, df], ignore_index=True)
-            df.to_parquet(aggr_path)
+            safe_write_dataframe_to_parquet(df, aggr_path, image_keys)
 
         num_episodes += meta.total_episodes
         num_frames += meta.total_frames
diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py
index 4df31028..ab04e61e 100644
--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -63,6 +63,7 @@ from lerobot.common.datasets.utils import (
     load_nested_dataset,
     load_stats,
     load_tasks,
+    safe_write_dataframe_to_parquet,
     update_chunk_file_indices,
     validate_episode_buffer,
     validate_frame,
@@ -1008,10 +1009,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         # Write the resulting dataframe from RAM to disk
         path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
         path.parent.mkdir(parents=True, exist_ok=True)
-        if len(self.meta.image_keys) > 0:
-            datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
-        else:
-            df.to_parquet(path)
+        safe_write_dataframe_to_parquet(df, path, self.meta.image_keys)
 
         # Update the Hugging Face dataset by reloading it.
         # This process should be fast because only the latest Parquet file has been modified.
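Note: both call sites above now route through safe_write_dataframe_to_parquet, which is
added to utils.py in the next diff. Below is a minimal usage sketch, not part of the
patch; the dataframe, column name, and output path are illustrative only:

    from pathlib import Path

    import numpy as np
    import pandas as pd

    from lerobot.common.datasets.utils import safe_write_dataframe_to_parquet

    # One raw HxWxC uint8 image per row, matching an "image" dtype feature.
    df = pd.DataFrame(
        {
            "index": [0, 1],
            "laptop": [np.zeros((64, 96, 3), dtype=np.uint8) for _ in range(2)],
        }
    )

    # With image columns present, the helper round-trips through
    # datasets.Dataset so the arrays are Arrow-encoded; otherwise it is a
    # plain df.to_parquet(path).
    safe_write_dataframe_to_parquet(df, Path("file-000.parquet"), image_keys=["laptop"])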
diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index 58dd9400..bdf3eba9 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -890,3 +890,11 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
             f"In episode_buffer not in features: {buffer_keys - set(features)}"
             f"In features not in episode_buffer: {set(features) - buffer_keys}"
         )
+
+
+def safe_write_dataframe_to_parquet(df: pandas.DataFrame, path: Path, image_keys: list[str]):
+    if len(image_keys) > 0:
+        # TODO(qlhoest): replace this weird syntax with `df.to_parquet(path)` only
+        datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
+    else:
+        df.to_parquet(path)
diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py
index 8759ac87..1557c3b7 100644
--- a/tests/datasets/test_datasets.py
+++ b/tests/datasets/test_datasets.py
@@ -71,7 +71,7 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
     dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create)
 
     root_init = tmp_path / "init"
-    dataset_init = lerobot_dataset_factory(root=root_init)
+    dataset_init = lerobot_dataset_factory(root=root_init, total_episodes=1, total_frames=1)
 
     init_attr = set(vars(dataset_init).keys())
     create_attr = set(vars(dataset_create).keys())
diff --git a/tests/fixtures/constants.py b/tests/fixtures/constants.py
index 5e5c762c..81b9be39 100644
--- a/tests/fixtures/constants.py
+++ b/tests/fixtures/constants.py
@@ -29,8 +29,8 @@ DUMMY_MOTOR_FEATURES = {
     },
 }
 DUMMY_CAMERA_FEATURES = {
-    "laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
-    "phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
+    "laptop": {"shape": (64, 96, 3), "names": ["height", "width", "channels"], "info": None},
+    "phone": {"shape": (64, 96, 3), "names": ["height", "width", "channels"], "info": None},
 }
 DEFAULT_FPS = 30
 DUMMY_VIDEO_INFO = {
diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py
index 16547645..81dbb575 100644
--- a/tests/fixtures/dataset_factories.py
+++ b/tests/fixtures/dataset_factories.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import random
+import shutil
 from functools import partial
 from pathlib import Path
 from typing import Protocol
@@ -37,6 +38,7 @@ from lerobot.common.datasets.utils import (
     get_hf_features_from_features,
     hf_transform_to_torch,
 )
+from lerobot.common.datasets.video_utils import encode_video_frames
 from tests.fixtures.constants import (
     DEFAULT_FPS,
     DUMMY_CAMERA_FEATURES,
@@ -95,7 +97,7 @@ def features_factory():
     def _create_features(
         motor_features: dict = DUMMY_MOTOR_FEATURES,
         camera_features: dict = DUMMY_CAMERA_FEATURES,
-        use_videos: bool = True,
+        use_videos: bool = False,
     ) -> dict:
         if use_videos:
             camera_ft = {
@@ -129,7 +131,7 @@ def info_factory(features_factory):
         video_path: str = DEFAULT_VIDEO_PATH,
         motor_features: dict = DUMMY_MOTOR_FEATURES,
         camera_features: dict = DUMMY_CAMERA_FEATURES,
-        use_videos: bool = True,
+        use_videos: bool = False,
     ) -> dict:
         features = features_factory(motor_features, camera_features, use_videos)
         return {
@@ -302,6 +304,38 @@ def episodes_factory(tasks_factory, stats_factory):
     return _create_episodes
 
 
+@pytest.fixture(scope="session")
+def create_videos(info_factory, img_array_factory):
+    def _create_video_directory(
+        root: Path,
+        info: dict | None = None,
+        total_episodes: int = 3,
+        total_frames: int = 150,
+        total_tasks: int = 1,
+    ):
+        if info is None:
+            info = info_factory(
+                total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
+            )
+
+        video_feats = {key: feats for key, feats in info["features"].items() if feats["dtype"] == "video"}
+        for key, ft in video_feats.items():
+            # create and save images ("shape" is (height, width, channels))
+            tmp_dir = root / "tmp_images"
+            tmp_dir.mkdir(parents=True, exist_ok=True)
+            for frame_index in range(info["total_frames"]):
+                img = img_array_factory(height=ft["shape"][0], width=ft["shape"][1])
+                pil_img = PIL.Image.fromarray(img)
+                path = tmp_dir / f"frame-{frame_index:06d}.png"
+                pil_img.save(path)
+
+            video_path = root / DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=0, file_index=0)
+            encode_video_frames(tmp_dir, video_path, fps=ft["video.fps"])
+            shutil.rmtree(tmp_dir)
+
+    return _create_video_directory
+
+
 @pytest.fixture(scope="session")
 def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
     def _create_hf_dataset(
@@ -338,7 +372,7 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
         for key, ft in features.items():
             if ft["dtype"] == "image":
                 robot_cols[key] = [
-                    img_array_factory(height=ft["shapes"][1], width=ft["shapes"][0])
+                    img_array_factory(height=ft["shape"][0], width=ft["shape"][1])
                     for _ in range(len(index_col))
                 ]
             elif ft["shape"][0] > 1 and ft["dtype"] != "video":
@@ -439,6 +473,7 @@ def lerobot_dataset_factory(
         hf_dataset: datasets.Dataset | None = None,
         **kwargs,
     ) -> LeRobotDataset:
+        # Instantiate objects
         if info is None:
             info = info_factory(
                 total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
@@ -448,19 +483,18 @@ def lerobot_dataset_factory(
         if tasks is None:
             tasks = tasks_factory(total_tasks=info["total_tasks"])
         if episodes_metadata is None:
-            video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
             episodes_metadata = episodes_factory(
                 features=info["features"],
                 fps=info["fps"],
                 total_episodes=info["total_episodes"],
                 total_frames=info["total_frames"],
-                video_keys=video_keys,
                 tasks=tasks,
                 multi_task=multi_task,
             )
         if not hf_dataset:
             hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes_metadata, fps=info["fps"])
 
+        # Write data on disk
         mock_snapshot_download = mock_snapshot_download_factory(
             info=info,
             stats=stats,
diff --git a/tests/fixtures/hub.py b/tests/fixtures/hub.py
index 6caa9246..c218d592 100644
--- a/tests/fixtures/hub.py
+++ b/tests/fixtures/hub.py
@@ -22,6 +22,7 @@ from lerobot.common.datasets.utils import (
     DEFAULT_DATA_PATH,
     DEFAULT_EPISODES_PATH,
     DEFAULT_TASKS_PATH,
+    DEFAULT_VIDEO_PATH,
     INFO_PATH,
     STATS_PATH,
 )
@@ -40,6 +41,7 @@ def mock_snapshot_download_factory(
     create_episodes,
     hf_dataset_factory,
     create_hf_dataset,
+    create_videos,
 ):
     """
     This factory allows to patch snapshot_download such that when called, it will create expected files rather
@@ -91,40 +93,48 @@
             DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0),
         ]
 
+        video_keys = [key for key, feats in info["features"].items() if feats["dtype"] == "video"]
+        for key in video_keys:
+            all_files.append(DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=0, file_index=0))
+
         allowed_files = filter_repo_objects(
             all_files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
         )
 
-        has_info = False
-        has_tasks = False
-        has_episodes = False
-        has_stats = False
-        has_data = False
+        request_info = False
+        request_tasks = False
+        request_episodes = False
+        request_stats = False
+        request_data = False
+        request_videos = False
 
         for rel_path in allowed_files:
             if rel_path.startswith("meta/info.json"):
-                has_info = True
+                request_info = True
             elif rel_path.startswith("meta/stats"):
-                has_stats = True
+                request_stats = True
             elif rel_path.startswith("meta/tasks"):
-                has_tasks = True
+                request_tasks = True
            elif rel_path.startswith("meta/episodes"):
-                has_episodes = True
+                request_episodes = True
            elif rel_path.startswith("data/"):
-                has_data = True
+                request_data = True
+            elif rel_path.startswith("videos/"):
+                request_videos = True
             else:
                 raise ValueError(f"{rel_path} not supported.")
 
-        if has_info:
+        if request_info:
             create_info(local_dir, info)
-        if has_stats:
+        if request_stats:
             create_stats(local_dir, stats)
-        if has_tasks:
+        if request_tasks:
             create_tasks(local_dir, tasks)
-        if has_episodes:
+        if request_episodes:
             create_episodes(local_dir, episodes)
-        # TODO(rcadene): create_videos?
-        if has_data:
+        if request_data:
             create_hf_dataset(local_dir, hf_dataset)
+        if request_videos:
+            create_videos(root=local_dir, info=info)
 
         return str(local_dir)
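Note: the mock only materializes what survives the allow/ignore patterns, so the new
request_videos flag stays False (and no mp4 gets encoded) for metadata-only downloads.
A rough sketch of that filtering with the real filter_repo_objects helper from
huggingface_hub; the file list below is illustrative:

    from huggingface_hub.utils import filter_repo_objects

    # Illustrative repo listing, mirroring the paths built in the mock above.
    all_files = [
        "meta/info.json",
        "data/chunk-000/file-000.parquet",
        "videos/laptop/chunk-000/file-000.mp4",
    ]

    # A metadata-only request keeps data/ and videos/ out of allowed_files,
    # so neither create_hf_dataset nor create_videos runs.
    allowed = list(filter_repo_objects(all_files, allow_patterns=["meta/*"]))
    assert allowed == ["meta/info.json"]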
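And a sketch of what a test opting back into videos could look like, assuming
use_videos=True is forwarded from lerobot_dataset_factory's **kwargs down to
info_factory (that plumbing is not shown in this patch, so treat the keyword as
hypothetical):

    def test_dataset_with_videos(tmp_path, lerobot_dataset_factory):
        # With use_videos=True, info advertises "video" dtype features, and the
        # patched snapshot_download encodes real mp4 files via create_videos.
        dataset = lerobot_dataset_factory(
            root=tmp_path / "with_videos",
            total_episodes=1,
            total_frames=10,
            use_videos=True,  # tests now default to False
        )
        item = dataset[0]
        assert "laptop" in item  # camera key from DUMMY_CAMERA_FEATURES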