Aggregate works

This commit is contained in:
Remi Cadene
2025-02-23 18:18:46 +00:00
parent e2e6f6e666
commit c36d2253d0
4 changed files with 106 additions and 26 deletions

View File

@@ -1,10 +1,12 @@
import shutil
from pathlib import Path
import logging
import subprocess
import pandas as pd
import tqdm
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.utils import write_episode, write_episode_stats, write_info, write_task
from lerobot.common.utils.utils import init_logging
def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
@@ -14,7 +16,7 @@ def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
robot_type = all_metadata[0].robot_type
features = all_metadata[0].features
for meta in all_metadata:
for meta in tqdm.tqdm(all_metadata):
if fps != meta.fps:
raise ValueError(f"Same fps is expected, but got fps={meta.fps} instead of {fps}.")
if robot_type != meta.robot_type:
@@ -39,6 +41,7 @@ def get_update_episode_and_task_func(episode_index_to_add, task_index_to_global_
def aggregate_datasets(all_metadata: list[LeRobotDatasetMetadata], repo_id: str, root=None):
logging.info("start aggregate_datasets")
fps, robot_type, features = validate_all_metadata(all_metadata)
# Create resulting dataset folder
@@ -50,11 +53,12 @@ def aggregate_datasets(all_metadata: list[LeRobotDatasetMetadata], repo_id: str,
root=root,
)
logging.info("find all tasks")
# find all tasks, deduplicate them, create new task indices for each dataset
# indexed by dataset index
datasets_task_index_to_aggr_task_index = {}
aggr_task_index = 0
for dataset_index, meta in enumerate(all_metadata):
for dataset_index, meta in enumerate(tqdm.tqdm(all_metadata)):
task_index_to_aggr_task_index = {}
for task_index, task in meta.tasks.items():
@@ -69,8 +73,9 @@ def aggregate_datasets(all_metadata: list[LeRobotDatasetMetadata], repo_id: str,
datasets_task_index_to_aggr_task_index[dataset_index] = task_index_to_aggr_task_index
logging.info("cp data and videos")
aggr_episode_index_shift = 0
for dataset_index, meta in enumerate(all_metadata):
for dataset_index, meta in enumerate(tqdm.tqdm(all_metadata)):
# cp data
for episode_index in range(meta.total_episodes):
aggr_episode_index = episode_index + aggr_episode_index_shift
@@ -94,7 +99,10 @@ def aggregate_datasets(all_metadata: list[LeRobotDatasetMetadata], repo_id: str,
video_path = meta.root / meta.get_video_file_path(episode_index, vid_key)
aggr_video_path = aggr_meta.root / aggr_meta.get_video_file_path(aggr_episode_index, vid_key)
aggr_video_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(video_path, aggr_video_path)
# shutil.copy(video_path, aggr_video_path)
copy_command = f"cp {video_path} {aggr_video_path} &"
subprocess.Popen(copy_command, shell=True)
# populate episodes
for episode_index, episode_dict in meta.episodes.items():
@@ -109,11 +117,13 @@ def aggregate_datasets(all_metadata: list[LeRobotDatasetMetadata], repo_id: str,
# populate info
aggr_meta.info["total_episodes"] += meta.total_episodes
aggr_meta.info["total_frames"] += meta.total_episodes
aggr_meta.info["total_frames"] += meta.total_frames
aggr_meta.info["total_videos"] += len(aggr_meta.video_keys) * meta.total_episodes
aggr_episode_index_shift += meta.total_episodes
logging.info("write meta data")
aggr_meta.info["total_chunks"] = aggr_meta.get_episode_chunk(aggr_episode_index_shift - 1)
aggr_meta.info["splits"] = {"train": f"0:{aggr_meta.info['total_episodes']}"}
@@ -133,30 +143,30 @@ def aggregate_datasets(all_metadata: list[LeRobotDatasetMetadata], repo_id: str,
if __name__ == "__main__":
init_logging()
repo_id = "cadene/droid"
aggr_repo_id = "cadene/droid"
datetime = "2025-02-22_11-23-54"
root = Path(f"/tmp/{repo_id}")
if root.exists():
shutil.rmtree(root)
# root = Path(f"/tmp/{repo_id}")
# if root.exists():
# shutil.rmtree(root)
root = None
all_metadata = [
LeRobotDatasetMetadata(f"{repo_id}_{datetime}_world_2048_rank_0"),
LeRobotDatasetMetadata(f"{repo_id}_{datetime}_world_2048_rank_1"),
]
# all_metadata = [LeRobotDatasetMetadata(f"{repo_id}_{datetime}_world_2048_rank_{rank}") for rank in range(2048)]
aggregate_datasets(
all_metadata,
repo_id,
root=root,
)
# aggregate_datasets(
# all_metadata,
# aggr_repo_id,
# root=root,
# )
aggr_dataset = LeRobotDataset(
repo_id=repo_id,
repo_id=aggr_repo_id,
root=root,
)
aggr_dataset.push_to_hub()
aggr_dataset.push_to_hub(tags=["openx"])
# for meta in all_metadata:
# dataset = LeRobotDataset(repo_id=meta.repo_id, root=meta.root)
# dataset.push_to_hub()
# dataset.push_to_hub(tags=["openx"])