From 3666ac9346afd3a11732555159cfec552d944ad9 Mon Sep 17 00:00:00 2001
From: Remi Cadene
Date: Sat, 1 Mar 2025 19:07:22 +0000
Subject: [PATCH] WIP UploadDataset

---
 .../port_datasets/openx_rlds_datatrove.py | 290 +++++++++++++++---
 lerobot/common/datasets/aggregate.py      |   9 +-
 2 files changed, 247 insertions(+), 52 deletions(-)

diff --git a/examples/port_datasets/openx_rlds_datatrove.py b/examples/port_datasets/openx_rlds_datatrove.py
index b49105445..1fb200d9a 100644
--- a/examples/port_datasets/openx_rlds_datatrove.py
+++ b/examples/port_datasets/openx_rlds_datatrove.py
@@ -1,21 +1,35 @@
 import datetime as dt
+import logging
+import os
+import random
+import time
 from pathlib import Path
 
 from datatrove.executor import LocalPipelineExecutor
 from datatrove.executor.slurm import SlurmPipelineExecutor
 from datatrove.pipeline.base import PipelineStep
+from huggingface_hub import CommitOperationAdd, HfApi, create_commit, preupload_lfs_files
+from huggingface_hub.constants import REPOCARD_NAME
+from huggingface_hub.utils import HfHubHTTPError
+
+from lerobot.common.datasets.aggregate import aggregate_datasets
+from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
+from lerobot.common.datasets.utils import create_lerobot_dataset_card
+
+BASE_DELAY = 0.1
+MAX_RETRIES = 12
 
 
 class PortOpenXDataset(PipelineStep):
     def __init__(
         self,
-        raw_dir: Path,
+        raw_dir: Path | str,
         repo_id: str = None,
         image_writer_process: int = 0,
         image_writer_threads: int = 8,
     ):
         super().__init__()
-        self.raw_dir = raw_dir
+        self.raw_dir = Path(raw_dir)
         self.repo_id = repo_id
         self.image_writer_process = image_writer_process
         self.image_writer_threads = image_writer_threads
@@ -45,8 +59,215 @@ class PortOpenXDataset(PipelineStep):
 
 
 class AggregateDatasets(PipelineStep):
+    def __init__(
+        self,
+        repo_ids: list[str],
+        aggregated_repo_id: str,
+    ):
+        super().__init__()
+        self.repo_ids = repo_ids
+        self.aggregated_repo_id = aggregated_repo_id
+
     def run(self, data=None, rank: int = 0, world_size: int = 1):
-        print("aggregation")
+        aggregate_datasets(self.repo_ids, self.aggregated_repo_id)
+
+
+class UploadDataset(PipelineStep):
+    def __init__(
+        self,
+        repo_id: str,
+        branch: str | None = None,
+        tags: list | None = None,
+        license: str | None = "apache-2.0",
+        private: bool = False,
+        **card_kwargs,
+    ):
+        super().__init__()
+        self.repo_id = repo_id
+        self.branch = branch
+        self.tags = tags
+        self.license = license
+        self.private = private
+        self.card_kwargs = card_kwargs
+
+        if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "0") != "1":
+            logging.warning(
+                'HF_HUB_ENABLE_HF_TRANSFER is not set to "1". Install hf_transfer and set the env '
+                "variable for faster uploads:\npip install hf-transfer\nexport HF_HUB_ENABLE_HF_TRANSFER=1"
+            )
+
+        self._repo_init = False
+        self.operations = []
+
+    def _create_repo(self, hub_api, meta):
+        hub_api.create_repo(
+            repo_id=self.repo_id,
+            private=self.private,
+            repo_type="dataset",
+            exist_ok=True,
+        )
+        if self.branch:
+            hub_api.create_branch(
+                repo_id=self.repo_id,
+                branch=self.branch,
+                repo_type="dataset",
+                exist_ok=True,
+            )
+
+        if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=self.branch):
+            card = create_lerobot_dataset_card(
+                tags=self.tags, dataset_info=meta.info, license=self.license, **self.card_kwargs
+            )
+            card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=self.branch)
+
+    def run(self, data=None, rank: int = 0, world_size: int = 1):
+        from lerobot.common.utils.utils import init_logging
+
+        init_logging()
+
+        meta = LeRobotDatasetMetadata(self.repo_id)
+
+        # TODO: list files, shard files, upload meta data for rank=0
+        filenames = []
+
+        raise NotImplementedError()
+
+        hub_api = HfApi()
+        if not self._repo_init:
+            self._create_repo(hub_api, meta)
+            self._repo_init = True
+
+        additions = [
+            CommitOperationAdd(path_in_repo=filename, path_or_fileobj=meta.root / filename)
+            for filename in filenames
+        ]
+        logging.info(f"Uploading {','.join(filenames)} to the hub...")
+        preupload_lfs_files(
+            repo_id=self.repo_id, repo_type="dataset", additions=additions, revision=self.branch
+        )
+        logging.info(f"Upload of {','.join(filenames)} to the hub complete!")
+        # if self.cleanup:
+        #     for filename in filenames:
+        #         self.local_working_dir.rm(filename)
+        self.operations.extend(additions)
+
+    def close(self, rank: int = 0):
+        filelist = list(self.output_mg.get_open_files().keys())
+        super().close()
+        if filelist:
+            logging.info(f"Starting upload of {len(filelist)} files to {self.repo_id}")
+            self.upload_files(*filelist)
+        retries = 0
+        while True:
+            try:
+                create_commit(
+                    self.repo_id,
+                    repo_type="dataset",
+                    operations=self.operations,
+                    commit_message=f"DataTrove upload ({len(self.operations)} files)",
+                    revision=self.branch,
+                )
+                break
+            except HfHubHTTPError as e:
+                if "A commit has happened since" in e.server_message:
+                    if retries >= MAX_RETRIES:
+                        logging.error(f"Failed to create commit after {MAX_RETRIES=}. Giving up.")
+                        raise e
+                    logging.info("Commit creation race condition issue. Waiting...")
+                    time.sleep(BASE_DELAY * 2**retries + random.uniform(0, 2))
+                    retries += 1
+                else:
+                    raise e
+
+
+def make_port_executor(raw_dir, repo_id, port_job_name, port_log_dir, slurm=True):
+    kwargs = {
+        "pipeline": [
+            PortOpenXDataset(raw_dir, repo_id),
+        ],
+        "logging_dir": str(port_log_dir),
+    }
+
+    if slurm:
+        kwargs.update(
+            {
+                "job_name": port_job_name,
+                "tasks": 2048,
+                "workers": 20,
+                "time": "08:00:00",
+                "partition": "hopper-cpu",
+                "cpus_per_task": 24,
+                "mem_per_cpu_gb": 2,
+                "max_array_launch_parallel": True,
+            }
+        )
+        executor = SlurmPipelineExecutor(**kwargs)
+    else:
+        kwargs.update(
+            {
+                "tasks": 1,
+                "workers": 1,
+            }
+        )
+        executor = LocalPipelineExecutor(**kwargs)
+
+    return executor
+
+
+def make_aggregate_executor(
+    repo_ids, aggr_repo_id, aggregate_job_name, aggregate_log_dir, depends=None, slurm=True
+):
+    kwargs = {
+        "pipeline": [
+            AggregateDatasets(repo_ids, aggr_repo_id),
+        ],
+        "logging_dir": str(aggregate_log_dir),
+        "tasks": 1,
+        "workers": 1,
+    }
+    if depends:
+        kwargs["depends"] = depends
+
+    if slurm:
+        kwargs.update(
+            {
+                "job_name": aggregate_job_name,
+                "time": "08:00:00",
+                "partition": "hopper-cpu",
+            }
+        )
+        executor = SlurmPipelineExecutor(**kwargs)
+    else:
+        executor = LocalPipelineExecutor(**kwargs)
+
+    return executor
+
+
+def make_upload_executor(repo_id, upload_job_name, upload_log_dir, depends=None, slurm=True):
+    kwargs = {
+        "pipeline": [
+            UploadDataset(repo_id),
+        ],
+        "logging_dir": str(upload_log_dir),
+        "tasks": 1,
+        "workers": 1,
+    }
+    if depends:
+        kwargs["depends"] = depends
+
+    if slurm:
+        kwargs.update(
+            {
+                "job_name": upload_job_name,
+                "time": "08:00:00",
+                "partition": "hopper-cpu",
+            }
+        )
+        executor = SlurmPipelineExecutor(**kwargs)
+    else:
+        executor = LocalPipelineExecutor(**kwargs)
+
+    return executor
 
 
 def main(slurm=True):
@@ -54,64 +275,35 @@ def main(slurm=True):
     # breakpoint()
     # for dir_ in Path("/fsx/remi_cadene/.cache/huggingface/lerobot/cadene").glob("droid_world*"):
     #     shutil.rmtree(dir_)
 
+    world = 2048
+    raw_dir = "/fsx/mustafa_shukor/droid"
     port_job_name = "port_openx_droid"
+    aggregate_job_name = "aggregate_openx_droid"
+    upload_job_name = "upload_openx_droid"
     logs_dir = Path("/fsx/remi_cadene/logs")
+    repo_id = "cadene/droid"
 
     now = dt.datetime.now()
     datetime = f"{now:%Y-%m-%d}_{now:%H-%M-%S}"
     # datetime = "2025-02-22_11-17-00"
 
     port_log_dir = logs_dir / f"{datetime}_{port_job_name}"
+    aggregate_log_dir = logs_dir / f"{datetime}_{aggregate_job_name}"
+    upload_log_dir = logs_dir / f"{datetime}_{upload_job_name}"
 
-    if slurm:
-        executor_class = SlurmPipelineExecutor
-        dist_extra_kwargs = {
-            "job_name": port_job_name,
-            "tasks": 2048,
-            # "workers": 20,  # 8 * 16,
-            "workers": 20,  # 8 * 16,
-            "time": "08:00:00",
-            "partition": "hopper-cpu",
-            "cpus_per_task": 24,
-            "mem_per_cpu_gb": 2,
-            "max_array_launch_parallel": True,
-        }
-    else:
-        executor_class = LocalPipelineExecutor
-        dist_extra_kwargs = {
-            "tasks": 1,
-            "workers": 1,
-        }
-
-    port_executor = executor_class(
-        pipeline=[
-            PortOpenXDataset(raw_dir=Path("/fsx/mustafa_shukor/droid"), repo_id=f"cadene/droid_{datetime}"),
-        ],
-        logging_dir=str(port_log_dir),
-        **dist_extra_kwargs,
-    )
+    port_executor = make_port_executor(raw_dir, repo_id, port_job_name, port_log_dir, slurm)
     port_executor.run()
 
-    # if slurm:
-    #     merge_extra_kwargs = {}
-    # else:
-    #     merge_extra_kwargs = {
-    #         "job_name": "aggregate",
-    #         "time": "00:01:00",
-    #         "partition": "hopper-cpu",
-    #     }
+    repo_ids = [f"{repo_id}_{datetime}_world_{world}_rank_{rank}" for rank in range(world)]
+    aggregate_executor = make_aggregate_executor(
+        repo_ids, repo_id, aggregate_job_name, aggregate_log_dir, port_executor, slurm
+    )
+    aggregate_executor.run()
 
-    # merge_executor = executor_class(
-    #     depends=dist_executor,
-    #     pipeline=[
-    #         Aggregate(),
-    #     ],
-    #     logging_dir=f"/fsx/remi_cadene/logs/openx_rlds_merge",
-    #     tasks=1,
-    #     workers=1,
-    #     **merge_extra_kwargs,
-    # )
-    # merge_executor.run()
+    upload_executor = make_upload_executor(
+        repo_id, upload_job_name, upload_log_dir, aggregate_executor, slurm
+    )
+    upload_executor.run()
 
 
 if __name__ == "__main__":
diff --git a/lerobot/common/datasets/aggregate.py b/lerobot/common/datasets/aggregate.py
index d891e0081..c927f6b35 100644
--- a/lerobot/common/datasets/aggregate.py
+++ b/lerobot/common/datasets/aggregate.py
@@ -40,17 +40,20 @@ def get_update_episode_and_task_func(episode_index_to_add, task_index_to_global_
     return _update
 
 
-def aggregate_datasets(all_metadata: list[LeRobotDatasetMetadata], repo_id: str, root=None):
+def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, aggr_root=None):
     logging.info("start aggregate_datasets")
+
+    all_metadata = [LeRobotDatasetMetadata(repo_id) for repo_id in repo_ids]
+
     fps, robot_type, features = validate_all_metadata(all_metadata)
 
     # Create resulting dataset folder
     aggr_meta = LeRobotDatasetMetadata.create(
-        repo_id=repo_id,
+        repo_id=aggr_repo_id,
         fps=fps,
         robot_type=robot_type,
         features=features,
-        root=root,
+        root=aggr_root,
     )
 
     logging.info("find all tasks")
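
Reviewer note (not part of the patch): the `# TODO: list files, shard files, upload meta data for rank=0` in `UploadDataset.run` still needs a concrete sharding strategy. Below is a minimal sketch of one option, assuming the files to upload live under `meta.root` and using a deterministic round-robin split; `shard_filenames` is a hypothetical helper, not an existing lerobot or datatrove API.

    from pathlib import Path

    def shard_filenames(root: Path, rank: int, world_size: int) -> list[str]:
        # Every rank builds the same sorted list, so the round-robin split is stable across the job array.
        all_files = sorted(p.relative_to(root).as_posix() for p in root.rglob("*") if p.is_file())
        # Rank i uploads files i, i + world_size, i + 2 * world_size, ...
        return all_files[rank::world_size]

Rank 0 could additionally commit the meta/ files, as the TODO suggests.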