diff --git a/lerobot/common/datasets/aggregate.py b/lerobot/common/datasets/aggregate.py
index c698c3a51..b3f7a2506 100644
--- a/lerobot/common/datasets/aggregate.py
+++ b/lerobot/common/datasets/aggregate.py
@@ -18,6 +18,7 @@ from lerobot.common.datasets.utils import (
     concat_video_files,
     get_parquet_file_size_in_mb,
     get_video_size_in_mb,
+    safe_write_dataframe_to_parquet,
     update_chunk_file_indices,
     write_info,
     write_stats,
@@ -100,7 +101,8 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
         ]
     )
     fps, robot_type, features = validate_all_metadata(all_metadata)
-    video_keys = [k for k, v in features.items() if v["dtype"] == "video"]
+    video_keys = [key for key in features if features[key]["dtype"] == "video"]
+    image_keys = [key for key in features if features[key]["dtype"] == "image"]
 
     # Initialize output dataset metadata
     dst_meta = LeRobotDatasetMetadata.create(
@@ -203,9 +205,21 @@ def aggregate_videos(src_meta, dst_meta, videos_idx):
            file_idx,
        )
 
-        # Update the video index tracking
-        video_idx["chunk_idx"] = chunk_idx
-        video_idx["file_idx"] = file_idx
+        if aggr_size_in_mb + size_in_mb >= DEFAULT_DATA_FILE_SIZE_IN_MB:
+            # Size limit is reached, prepare new parquet file
+            aggr_data_chunk_idx, aggr_data_file_idx = update_chunk_file_indices(
+                aggr_data_chunk_idx, aggr_data_file_idx, DEFAULT_CHUNK_SIZE
+            )
+            aggr_path = aggr_root / DEFAULT_DATA_PATH.format(
+                chunk_index=aggr_data_chunk_idx, file_index=aggr_data_file_idx
+            )
+            aggr_path.parent.mkdir(parents=True, exist_ok=True)
+            df.to_parquet(aggr_path)
+        else:
+            # Update the existing parquet file with new rows
+            aggr_df = pd.read_parquet(aggr_path)
+            df = pd.concat([aggr_df, df], ignore_index=True)
+            safe_write_dataframe_to_parquet(df, aggr_path, image_keys)
 
     return videos_idx
diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py
index cf70bbf46..ab04e61e0 100644
--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -31,7 +31,7 @@ from datasets import Dataset
 from huggingface_hub import HfApi, snapshot_download
 from huggingface_hub.constants import REPOCARD_NAME
 from huggingface_hub.errors import RevisionNotFoundError
-from torch.profiler import record_function
+
 from lerobot.common.constants import HF_LEROBOT_HOME
 from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
 from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
@@ -63,6 +63,7 @@ from lerobot.common.datasets.utils import (
     load_nested_dataset,
     load_stats,
     load_tasks,
+    safe_write_dataframe_to_parquet,
     update_chunk_file_indices,
     validate_episode_buffer,
     validate_frame,
@@ -1008,10 +1009,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         # Write the resulting dataframe from RAM to disk
         path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
         path.parent.mkdir(parents=True, exist_ok=True)
-        if len(self.meta.image_keys) > 0:
-            datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
-        else:
-            df.to_parquet(path)
+        safe_write_dataframe_to_parquet(df, path, self.meta.image_keys)
 
         # Update the Hugging Face dataset by reloading it.
         # This process should be fast because only the latest Parquet file has been modified.
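A minimal sketch of the append-or-rollover pattern the aggregate.py hunk above follows: if the incoming rows would push the current parquet file past the size limit, advance to a fresh chunk/file index and write a new file; otherwise append the rows to the existing file. The constants, the next_indices helper, and the chunk-XXX/file-XXX.parquet layout below are simplified stand-ins for DEFAULT_DATA_FILE_SIZE_IN_MB, update_chunk_file_indices, and DEFAULT_DATA_PATH, not the actual lerobot implementation.

from pathlib import Path

import pandas as pd

MAX_FILE_SIZE_IN_MB = 100  # stand-in for DEFAULT_DATA_FILE_SIZE_IN_MB
CHUNK_SIZE = 1000  # stand-in for DEFAULT_CHUNK_SIZE


def next_indices(chunk_idx: int, file_idx: int, chunk_size: int) -> tuple[int, int]:
    # Simplified stand-in for update_chunk_file_indices: move to the next file,
    # rolling over to a new chunk once the current chunk is full.
    file_idx += 1
    if file_idx >= chunk_size:
        chunk_idx += 1
        file_idx = 0
    return chunk_idx, file_idx


def append_or_rollover(
    df: pd.DataFrame,
    root: Path,
    current_path: Path,
    current_size_in_mb: float,
    incoming_size_in_mb: float,
    chunk_idx: int,
    file_idx: int,
) -> tuple[Path, int, int]:
    if current_size_in_mb + incoming_size_in_mb >= MAX_FILE_SIZE_IN_MB:
        # Size limit reached: start a fresh parquet file at the next chunk/file index.
        chunk_idx, file_idx = next_indices(chunk_idx, file_idx, CHUNK_SIZE)
        current_path = root / f"data/chunk-{chunk_idx:03d}/file-{file_idx:03d}.parquet"
        current_path.parent.mkdir(parents=True, exist_ok=True)
        df.to_parquet(current_path)
    else:
        # Under the limit: append the incoming rows to the existing parquet file.
        existing = pd.read_parquet(current_path)
        pd.concat([existing, df], ignore_index=True).to_parquet(current_path)
    return current_path, chunk_idx, file_idx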
diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index 58dd94007..bdf3eba97 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -890,3 +890,11 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
             f"In episode_buffer not in features: {buffer_keys - set(features)}"
             f"In features not in episode_buffer: {set(features) - buffer_keys}"
         )
+
+
+def safe_write_dataframe_to_parquet(df: pandas.DataFrame, path: Path, image_keys: list[str]):
+    if len(image_keys) > 0:
+        # TODO(qlhoest): replace this weird syntax with `df.to_parquet(path)` only
+        datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
+    else:
+        df.to_parquet(path)
diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py
index 8759ac879..1557c3b7a 100644
--- a/tests/datasets/test_datasets.py
+++ b/tests/datasets/test_datasets.py
@@ -71,7 +71,7 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
     dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create)
 
     root_init = tmp_path / "init"
-    dataset_init = lerobot_dataset_factory(root=root_init)
+    dataset_init = lerobot_dataset_factory(root=root_init, total_episodes=1, total_frames=1)
 
     init_attr = set(vars(dataset_init).keys())
     create_attr = set(vars(dataset_create).keys())
diff --git a/tests/fixtures/constants.py b/tests/fixtures/constants.py
index 5e5c762c8..81b9be39b 100644
--- a/tests/fixtures/constants.py
+++ b/tests/fixtures/constants.py
@@ -29,8 +29,8 @@ DUMMY_MOTOR_FEATURES = {
     },
 }
 DUMMY_CAMERA_FEATURES = {
-    "laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
-    "phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
+    "laptop": {"shape": (64, 96, 3), "names": ["height", "width", "channels"], "info": None},
+    "phone": {"shape": (64, 96, 3), "names": ["height", "width", "channels"], "info": None},
 }
 DEFAULT_FPS = 30
 DUMMY_VIDEO_INFO = {
diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py
index 165476454..81dbb5753 100644
--- a/tests/fixtures/dataset_factories.py
+++ b/tests/fixtures/dataset_factories.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import random
+import shutil
 from functools import partial
 from pathlib import Path
 from typing import Protocol
@@ -37,6 +38,7 @@ from lerobot.common.datasets.utils import (
     get_hf_features_from_features,
     hf_transform_to_torch,
 )
+from lerobot.common.datasets.video_utils import encode_video_frames
 from tests.fixtures.constants import (
     DEFAULT_FPS,
     DUMMY_CAMERA_FEATURES,
@@ -95,7 +97,7 @@ def features_factory():
     def _create_features(
         motor_features: dict = DUMMY_MOTOR_FEATURES,
         camera_features: dict = DUMMY_CAMERA_FEATURES,
-        use_videos: bool = True,
+        use_videos: bool = False,
     ) -> dict:
         if use_videos:
             camera_ft = {
@@ -129,7 +131,7 @@ def info_factory(features_factory):
         video_path: str = DEFAULT_VIDEO_PATH,
         motor_features: dict = DUMMY_MOTOR_FEATURES,
         camera_features: dict = DUMMY_CAMERA_FEATURES,
-        use_videos: bool = True,
+        use_videos: bool = False,
     ) -> dict:
         features = features_factory(motor_features, camera_features, use_videos)
         return {
@@ -302,6 +304,38 @@ def episodes_factory(tasks_factory, stats_factory):
     return _create_episodes
 
 
+@pytest.fixture(scope="session")
+def create_videos(info_factory, img_array_factory):
+    def _create_video_directory(
+        root: Path,
+        info: dict | None = None,
+        total_episodes: int = 3,
+        total_frames: int = 150,
+        total_tasks: int = 1,
+    ):
+        if info is None:
+            info = info_factory(
+                total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
+            )
+
+        video_feats = {key: feats for key, feats in info["features"].items() if feats["dtype"] == "video"}
+        for key, ft in video_feats.items():
+            # create and save images
+            tmp_dir = root / "tmp_images"
+            tmp_dir.mkdir(parents=True, exist_ok=True)
+            for frame_index in range(info["total_frames"]):
+                img = img_array_factory(height=ft["shape"][1], width=ft["shape"][0])
+                pil_img = PIL.Image.fromarray(img)
+                path = tmp_dir / f"frame-{frame_index:06d}.png"
+                pil_img.save(path)
+
+            video_path = root / DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=0, file_index=0)
+            encode_video_frames(tmp_dir, video_path, fps=ft["video.fps"])
+            shutil.rmtree(tmp_dir)
+
+    return _create_video_directory
+
+
 @pytest.fixture(scope="session")
 def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
     def _create_hf_dataset(
@@ -338,7 +372,7 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
         for key, ft in features.items():
             if ft["dtype"] == "image":
                 robot_cols[key] = [
-                    img_array_factory(height=ft["shapes"][1], width=ft["shapes"][0])
+                    img_array_factory(height=ft["shape"][1], width=ft["shape"][0])
                     for _ in range(len(index_col))
                 ]
             elif ft["shape"][0] > 1 and ft["dtype"] != "video":
@@ -439,6 +473,7 @@ def lerobot_dataset_factory(
         hf_dataset: datasets.Dataset | None = None,
         **kwargs,
     ) -> LeRobotDataset:
+        # Instantiate objects
         if info is None:
             info = info_factory(
                 total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
@@ -448,19 +483,18 @@ def lerobot_dataset_factory(
         if tasks is None:
             tasks = tasks_factory(total_tasks=info["total_tasks"])
         if episodes_metadata is None:
-            video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
             episodes_metadata = episodes_factory(
                 features=info["features"],
                 fps=info["fps"],
                 total_episodes=info["total_episodes"],
                 total_frames=info["total_frames"],
-                video_keys=video_keys,
                 tasks=tasks,
                 multi_task=multi_task,
             )
         if not hf_dataset:
             hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes_metadata, fps=info["fps"])
 
+        # Write data on disk
         mock_snapshot_download = mock_snapshot_download_factory(
             info=info,
             stats=stats,
diff --git a/tests/fixtures/hub.py b/tests/fixtures/hub.py
index 6caa92469..c218d592d 100644
--- a/tests/fixtures/hub.py
+++ b/tests/fixtures/hub.py
@@ -22,6 +22,7 @@ from lerobot.common.datasets.utils import (
     DEFAULT_DATA_PATH,
     DEFAULT_EPISODES_PATH,
     DEFAULT_TASKS_PATH,
+    DEFAULT_VIDEO_PATH,
     INFO_PATH,
     STATS_PATH,
 )
@@ -40,6 +41,7 @@ def mock_snapshot_download_factory(
     create_episodes,
     hf_dataset_factory,
     create_hf_dataset,
+    create_videos,
 ):
     """
     This factory allows to patch snapshot_download such that when called, it will create expected files rather
@@ -91,40 +93,48 @@ def mock_snapshot_download_factory(
             DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0),
         ]
 
+        video_keys = [key for key, feats in info["features"].items() if feats["dtype"] == "video"]
+        for key in video_keys:
+            all_files.append(DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=0, file_index=0))
+
         allowed_files = filter_repo_objects(
             all_files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
         )
 
-        has_info = False
-        has_tasks = False
-        has_episodes = False
-        has_stats = False
-        has_data = False
+        request_info = False
+        request_tasks = False
+        request_episodes = False
+        request_stats = False
+        request_data = False
+        request_videos = False
 
         for rel_path in allowed_files:
             if rel_path.startswith("meta/info.json"):
-                has_info = True
+                request_info = True
             elif rel_path.startswith("meta/stats"):
-                has_stats = True
+                request_stats = True
             elif rel_path.startswith("meta/tasks"):
-                has_tasks = True
+                request_tasks = True
            elif rel_path.startswith("meta/episodes"):
-                has_episodes = True
+                request_episodes = True
            elif rel_path.startswith("data/"):
-                has_data = True
+                request_data = True
+            elif rel_path.startswith("videos/"):
+                request_videos = True
             else:
                 raise ValueError(f"{rel_path} not supported.")
 
-        if has_info:
+        if request_info:
             create_info(local_dir, info)
-        if has_stats:
+        if request_stats:
             create_stats(local_dir, stats)
-        if has_tasks:
+        if request_tasks:
             create_tasks(local_dir, tasks)
-        if has_episodes:
+        if request_episodes:
             create_episodes(local_dir, episodes)
-        # TODO(rcadene): create_videos?
-        if has_data:
+        if request_data:
             create_hf_dataset(local_dir, hf_dataset)
+        if request_videos:
+            create_videos(root=local_dir, info=info)
 
         return str(local_dir)
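A small usage sketch of the new safe_write_dataframe_to_parquet helper, assuming the import path shown in the utils.py hunk above; the column names and array shapes here are illustrative only, not part of the diff.

import tempfile
from pathlib import Path

import numpy as np
import pandas as pd

from lerobot.common.datasets.utils import safe_write_dataframe_to_parquet

# Illustrative dataframe: one float-vector column and one image column (names are made up).
df = pd.DataFrame(
    {
        "observation.state": [np.zeros(6, dtype=np.float32) for _ in range(3)],
        "observation.images.laptop": [np.zeros((64, 96, 3), dtype=np.uint8) for _ in range(3)],
    }
)

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "data" / "chunk-000" / "file-000.parquet"
    path.parent.mkdir(parents=True, exist_ok=True)
    # With image keys present, the helper routes the dataframe through datasets.Dataset
    # before writing; with an empty list it falls back to a plain df.to_parquet(path).
    safe_write_dataframe_to_parquet(df, path, image_keys=["observation.images.laptop"])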