diff --git a/lerobot/common/datasets/aggregate.py b/lerobot/common/datasets/aggregate.py
index 373479a77..5b5768fdc 100644
--- a/lerobot/common/datasets/aggregate.py
+++ b/lerobot/common/datasets/aggregate.py
@@ -9,9 +9,15 @@ from lerobot.common.datasets.compute_stats import aggregate_stats
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
 from lerobot.common.datasets.utils import (
     DEFAULT_CHUNK_SIZE,
+    DEFAULT_DATA_FILE_SIZE_IN_MB,
     DEFAULT_DATA_PATH,
     DEFAULT_EPISODES_PATH,
+    DEFAULT_VIDEO_FILE_SIZE_IN_MB,
     DEFAULT_VIDEO_PATH,
+    concat_video_files,
+    get_parquet_file_size_in_mb,
+    get_video_size_in_mb,
+    update_chunk_file_indices,
     write_info,
     write_stats,
     write_tasks,
@@ -41,17 +47,18 @@ def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
     return fps, robot_type, features
 
 
-def get_update_episode_and_task_func(episode_index_to_add, old_tasks, new_tasks):
+def update_episode_and_task(df, episode_index_to_add, old_tasks, new_tasks):
     def _update(row):
         row["episode_index"] = row["episode_index"] + episode_index_to_add
         task = old_tasks.iloc[row["task_index"]].name
         row["task_index"] = new_tasks.loc[task].task_index.item()
         return row
 
-    return _update
+    return df.apply(_update, axis=1)
 
 
-def get_update_meta_func(
+def update_meta_data(
+    df,
     meta_chunk_index_to_add,
     meta_file_index_to_add,
     data_chunk_index_to_add,
@@ -74,7 +81,7 @@ def get_update_meta_func(
         row["dataset_to_index"] = row["dataset_to_index"] + frame_index_to_add
         return row
 
-    return _update
+    return df.apply(_update, axis=1)
 
 
 def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path] = None, aggr_root=None):
@@ -117,6 +124,7 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
     aggr_videos_file_idx = dict.fromkeys(video_keys, 0)
 
     for meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
+        # Aggregate episodes metadata
         meta_chunk_file_ids = {
             (c, f)
             for c, f in zip(
@@ -128,7 +136,8 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
         for chunk_idx, file_idx in meta_chunk_file_ids:
             path = meta.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
             df = pd.read_parquet(path)
-            update_meta_func = get_update_meta_func(
+            df = update_meta_data(
+                df,
                 aggr_meta_chunk_idx,
                 aggr_meta_file_idx,
                 aggr_data_chunk_idx,
@@ -137,20 +146,29 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
                 aggr_videos_file_idx,
                 num_frames,
             )
-            df = df.apply(update_meta_func, axis=1)
 
             aggr_path = aggr_root / DEFAULT_EPISODES_PATH.format(
                 chunk_index=aggr_meta_chunk_idx, file_index=aggr_meta_file_idx
             )
+
+            if aggr_path.exists():
+                size_in_mb = get_parquet_file_size_in_mb(path)
+                aggr_size_in_mb = get_parquet_file_size_in_mb(aggr_path)
+
+                if aggr_size_in_mb + size_in_mb >= DEFAULT_DATA_FILE_SIZE_IN_MB:
+                    # Size limit is reached, prepare new parquet file
+                    aggr_meta_chunk_idx, aggr_meta_file_idx = update_chunk_file_indices(
+                        aggr_meta_chunk_idx, aggr_meta_file_idx, DEFAULT_CHUNK_SIZE
+                    )
+                else:
+                    # Update the existing parquet file with new rows
+                    aggr_df = pd.read_parquet(aggr_path)
+                    df = pd.concat([aggr_df, df], ignore_index=True)
+
             aggr_path.parent.mkdir(parents=True, exist_ok=True)
             df.to_parquet(aggr_path)
 
-            aggr_meta_file_idx += 1
-            if aggr_meta_file_idx >= DEFAULT_CHUNK_SIZE:
-                aggr_meta_file_idx = 0
-                aggr_meta_chunk_idx += 1
-
-        # cp videos
+        # Aggregate videos if any
         for key in video_keys:
            video_chunk_file_ids = {
                 (c, f)
@@ -169,17 +187,32 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
                     chunk_index=aggr_videos_chunk_idx[key],
                     file_index=aggr_videos_file_idx[key],
                 )
-                aggr_path.parent.mkdir(parents=True, exist_ok=True)
-                shutil.copy(str(path), str(aggr_path))
+                if aggr_path.exists():
+                    size_in_mb = get_video_size_in_mb(path)
+                    aggr_size_in_mb = get_video_size_in_mb(aggr_path)
 
-                # copy_command = f"cp {video_path} {aggr_video_path} &"
-                # subprocess.Popen(copy_command, shell=True)
+                    if aggr_size_in_mb + size_in_mb >= DEFAULT_VIDEO_FILE_SIZE_IN_MB:
+                        # Size limit is reached, prepare new video file
+                        aggr_videos_chunk_idx[key], aggr_videos_file_idx[key] = update_chunk_file_indices(
+                            aggr_videos_chunk_idx[key], aggr_videos_file_idx[key], DEFAULT_CHUNK_SIZE
+                        )
+                    else:
+                        # Concatenate the new video into the existing video file
+                        concat_video_files(
+                            [aggr_path, path],
+                            aggr_root,
+                            key,
+                            aggr_videos_chunk_idx[key],
+                            aggr_videos_file_idx[key],
+                        )
+                else:
+                    aggr_path.parent.mkdir(parents=True, exist_ok=True)
+                    shutil.copy(str(path), str(aggr_path))
 
-                aggr_videos_file_idx[key] += 1
-                if aggr_videos_file_idx[key] >= DEFAULT_CHUNK_SIZE:
-                    aggr_videos_file_idx[key] = 0
-                    aggr_videos_chunk_idx[key] += 1
+                # copy_command = f"cp {video_path} {aggr_video_path} &"
+                # subprocess.Popen(copy_command, shell=True)
 
+        # Aggregate data
         data_chunk_file_ids = {
             (c, f)
             for c, f in zip(meta.episodes["data/chunk_index"], meta.episodes["data/file_index"], strict=False)
@@ -188,20 +221,28 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
             path = meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
             df = pd.read_parquet(path)
             # TODO(rcadene): update frame index
-            update_data_func = get_update_episode_and_task_func(num_episodes, meta.tasks, aggr_meta.tasks)
-            df = df.apply(update_data_func, axis=1)
+            df = update_episode_and_task(df, num_episodes, meta.tasks, aggr_meta.tasks)
 
             aggr_path = aggr_root / DEFAULT_DATA_PATH.format(
                 chunk_index=aggr_data_chunk_idx, file_index=aggr_data_file_idx
             )
+            if aggr_path.exists():
+                size_in_mb = get_parquet_file_size_in_mb(path)
+                aggr_size_in_mb = get_parquet_file_size_in_mb(aggr_path)
+
+                if aggr_size_in_mb + size_in_mb >= DEFAULT_DATA_FILE_SIZE_IN_MB:
+                    # Size limit is reached, prepare new parquet file
+                    aggr_data_chunk_idx, aggr_data_file_idx = update_chunk_file_indices(
+                        aggr_data_chunk_idx, aggr_data_file_idx, DEFAULT_CHUNK_SIZE
+                    )
+                else:
+                    # Update the existing parquet file with new rows
+                    aggr_df = pd.read_parquet(aggr_path)
+                    df = pd.concat([aggr_df, df], ignore_index=True)
+
             aggr_path.parent.mkdir(parents=True, exist_ok=True)
             df.to_parquet(aggr_path)
 
-            aggr_data_file_idx += 1
-            if aggr_data_file_idx >= DEFAULT_CHUNK_SIZE:
-                aggr_data_file_idx = 0
-                aggr_data_chunk_idx += 1
-
         num_episodes += meta.total_episodes
         num_frames += meta.total_frames
 
@@ -209,6 +250,7 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
     write_tasks(aggr_meta.tasks, aggr_meta.root)
 
     logging.info("write info")
+    aggr_meta.info["total_tasks"] = len(aggr_meta.tasks)
     aggr_meta.info["total_episodes"] = sum([meta.total_episodes for meta in all_metadata])
     aggr_meta.info["total_frames"] = sum([meta.total_frames for meta in all_metadata])
     aggr_meta.info["splits"] = {"train": f"0:{aggr_meta.total_episodes}"}
@@ -221,29 +263,21 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
 if __name__ == "__main__":
     init_logging()
 
-    repo_id = "cadene/droid"
-    aggr_repo_id = "cadene/droid"
-    datetime = "2025-02-22_11-23-54"
+    aggr_repo_id = "cadene/aggregate_test"
+    aggr_root = Path(f"/tmp/{aggr_repo_id}")
+    if aggr_root.exists():
+        shutil.rmtree(aggr_root)
 
-    # root = Path(f"/tmp/{repo_id}")
-    # if root.exists():
-    #     shutil.rmtree(root)
-    root = None
-
-    # all_metadata = [LeRobotDatasetMetadata(f"{repo_id}_{datetime}_world_2048_rank_{rank}") for rank in range(2048)]
-
-    # aggregate_datasets(
-    #     all_metadata,
-    #     aggr_repo_id,
-    #     root=root,
-    # )
-
-    aggr_dataset = LeRobotDataset(
-        repo_id=aggr_repo_id,
-        root=root,
+    aggregate_datasets(
+        ["lerobot/aloha_sim_transfer_cube_human", "lerobot/aloha_sim_insertion_human"],
+        aggr_repo_id,
+        aggr_root=aggr_root,
     )
-    aggr_dataset.push_to_hub(tags=["openx"])
 
-    # for meta in all_metadata:
-    #     dataset = LeRobotDataset(repo_id=meta.repo_id, root=meta.root)
-    #     dataset.push_to_hub(tags=["openx"])
+    aggr_dataset = LeRobotDataset(repo_id=aggr_repo_id, root=aggr_root)
+
+    for i in tqdm.tqdm(range(len(aggr_dataset))):
+        aggr_dataset[i]
+        pass
+
+    aggr_dataset.push_to_hub(tags=["openx"])