Dataset v3 (#1412)

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Co-authored-by: Remi Cadene <re.cadene@gmail.com>
Co-authored-by: Tavish <tavish9.chen@gmail.com>
Co-authored-by: fracapuano <francesco.capuano@huggingface.co>
Co-authored-by: CarolinePascal <caroline8.pascal@gmail.com>
This commit is contained in:
Michel Aractingi
2025-09-15 09:53:30 +02:00
committed by GitHub
parent d602e8169c
commit f55c6e89f0
50 changed files with 4642 additions and 4092 deletions

View File

@@ -92,11 +92,11 @@ print(dataset.hf_dataset)
# LeRobot datasets also subclasses PyTorch datasets so you can do everything you know and love from working
# with the latter, like iterating through the dataset.
# The __getitem__ iterates over the frames of the dataset. Since our datasets are also structured by
# episodes, you can access the frame indices of any episode using the episode_data_index. Here, we access
# episodes, you can access the frame indices of any episode using dataset.meta.episodes. Here, we access
# frame indices associated to the first episode:
episode_index = 0
from_idx = dataset.episode_data_index["from"][episode_index].item()
to_idx = dataset.episode_data_index["to"][episode_index].item()
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
# Then we grab all the image frames from the first camera:
camera_key = dataset.meta.camera_keys[0]

View File

@@ -0,0 +1,85 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
from pathlib import Path
def find_missing_workers(completions_dir, world_size):
"""Find workers that are not completed and returns their indices."""
full = list(range(world_size))
completed = []
for path in completions_dir.glob("*"):
if path.name in [".", ".."]:
continue
index = path.name.lstrip("0")
index = 0 if index == "" else int(index)
completed.append(index)
missing_workers = set(full) - set(completed)
return missing_workers
def find_output_files(slurm_dir, worker_indices):
"""Find output files associated to worker indices, and return tuples
of (worker index, output file path)
"""
out_files = []
for path in slurm_dir.glob("*.out"):
_, worker_id = path.name.replace(".out", "").split("_")
worker_id = int(worker_id)
if worker_id in worker_indices:
out_files.append((worker_id, path))
return out_files
def display_error_files(logs_dir, job_name):
executor_path = Path(logs_dir) / job_name / "executor.json"
completions_dir = Path(logs_dir) / job_name / "completions"
with open(executor_path) as f:
executor = json.load(f)
missing_workers = find_missing_workers(completions_dir, executor["world_size"])
for missing in sorted(missing_workers)[::-1]:
print(missing)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--logs-dir",
type=str,
help="Path to logs directory for `datatrove`.",
)
parser.add_argument(
"--job-name",
type=str,
default="port_droid",
help="Job name used in slurm, and name of the directory created inside the provided logs directory.",
)
args = parser.parse_args()
display_error_files(**vars(args))
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,430 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import time
from pathlib import Path
import numpy as np
import tensorflow_datasets as tfds
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.utils.utils import get_elapsed_time_in_days_hours_minutes_seconds
DROID_SHARDS = 2048
DROID_FPS = 15
DROID_ROBOT_TYPE = "Franka"
# Dataset schema slightly adapted from: https://droid-dataset.github.io/droid/the-droid-dataset.html#-dataset-schema
DROID_FEATURES = {
# true on first step of the episode
"is_first": {
"dtype": "bool",
"shape": (1,),
"names": None,
},
# true on last step of the episode
"is_last": {
"dtype": "bool",
"shape": (1,),
"names": None,
},
# true on last step of the episode if it is a terminal step, True for demos
"is_terminal": {
"dtype": "bool",
"shape": (1,),
"names": None,
},
# language_instruction is also stored as "task" to follow LeRobot standard
"language_instruction": {
"dtype": "string",
"shape": (1,),
"names": None,
},
"language_instruction_2": {
"dtype": "string",
"shape": (1,),
"names": None,
},
"language_instruction_3": {
"dtype": "string",
"shape": (1,),
"names": None,
},
"observation.state.gripper_position": {
"dtype": "float32",
"shape": (1,),
"names": {
"axes": ["gripper"],
},
},
"observation.state.cartesian_position": {
"dtype": "float32",
"shape": (6,),
"names": {
"axes": ["x", "y", "z", "roll", "pitch", "yaw"],
},
},
"observation.state.joint_position": {
"dtype": "float32",
"shape": (7,),
"names": {
"axes": ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6"],
},
},
# Add this new feature to follow LeRobot standard of using joint position + gripper
"observation.state": {
"dtype": "float32",
"shape": (8,),
"names": {
"axes": ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6", "gripper"],
},
},
# Initially called wrist_image_left
"observation.images.wrist_left": {
"dtype": "video",
"shape": (180, 320, 3),
"names": [
"height",
"width",
"channels",
],
},
# Initially called exterior_image_1_left
"observation.images.exterior_1_left": {
"dtype": "video",
"shape": (180, 320, 3),
"names": [
"height",
"width",
"channels",
],
},
# Initially called exterior_image_2_left
"observation.images.exterior_2_left": {
"dtype": "video",
"shape": (180, 320, 3),
"names": [
"height",
"width",
"channels",
],
},
"action.gripper_position": {
"dtype": "float32",
"shape": (1,),
"names": {
"axes": ["gripper"],
},
},
"action.gripper_velocity": {
"dtype": "float32",
"shape": (1,),
"names": {
"axes": ["gripper"],
},
},
"action.cartesian_position": {
"dtype": "float32",
"shape": (6,),
"names": {
"axes": ["x", "y", "z", "roll", "pitch", "yaw"],
},
},
"action.cartesian_velocity": {
"dtype": "float32",
"shape": (6,),
"names": {
"axes": ["x", "y", "z", "roll", "pitch", "yaw"],
},
},
"action.joint_position": {
"dtype": "float32",
"shape": (7,),
"names": {
"axes": ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6"],
},
},
"action.joint_velocity": {
"dtype": "float32",
"shape": (7,),
"names": {
"axes": ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6"],
},
},
# This feature was called "action" in RLDS dataset and consists of [6x joint velocities, 1x gripper position]
"action.original": {
"dtype": "float32",
"shape": (7,),
"names": {
"axes": ["x", "y", "z", "roll", "pitch", "yaw", "gripper"],
},
},
# Add this new feature to follow LeRobot standard of using joint position + gripper
"action": {
"dtype": "float32",
"shape": (8,),
"names": {
"axes": ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6", "gripper"],
},
},
"discount": {
"dtype": "float32",
"shape": (1,),
"names": None,
},
"reward": {
"dtype": "float32",
"shape": (1,),
"names": None,
},
# Meta data that are the same for all frames in the episode
"task_category": {
"dtype": "string",
"shape": (1,),
"names": None,
},
"building": {
"dtype": "string",
"shape": (1,),
"names": None,
},
"collector_id": {
"dtype": "string",
"shape": (1,),
"names": None,
},
"date": {
"dtype": "string",
"shape": (1,),
"names": None,
},
"camera_extrinsics.wrist_left": {
"dtype": "float32",
"shape": (6,),
"names": {
"axes": ["x", "y", "z", "roll", "pitch", "yaw"],
},
},
"camera_extrinsics.exterior_1_left": {
"dtype": "float32",
"shape": (6,),
"names": {
"axes": ["x", "y", "z", "roll", "pitch", "yaw"],
},
},
"camera_extrinsics.exterior_2_left": {
"dtype": "float32",
"shape": (6,),
"names": {
"axes": ["x", "y", "z", "roll", "pitch", "yaw"],
},
},
"is_episode_successful": {
"dtype": "bool",
"shape": (1,),
"names": None,
},
}
def is_episode_successful(tf_episode_metadata):
# Adapted from: https://github.com/droid-dataset/droid_policy_learning/blob/dd1020eb20d981f90b5ff07dc80d80d5c0cb108b/robomimic/utils/rlds_utils.py#L8
return "/success/" in tf_episode_metadata["file_path"].numpy().decode()
def generate_lerobot_frames(tf_episode):
m = tf_episode["episode_metadata"]
frame_meta = {
"task_category": m["building"].numpy().decode(),
"building": m["building"].numpy().decode(),
"collector_id": m["collector_id"].numpy().decode(),
"date": m["date"].numpy().decode(),
"camera_extrinsics.wrist_left": m["extrinsics_wrist_cam"].numpy(),
"camera_extrinsics.exterior_1_left": m["extrinsics_exterior_cam_1"].numpy(),
"camera_extrinsics.exterior_2_left": m["extrinsics_exterior_cam_2"].numpy(),
"is_episode_successful": np.array([is_episode_successful(m)]),
}
for f in tf_episode["steps"]:
# Dataset schema slightly adapted from: https://droid-dataset.github.io/droid/the-droid-dataset.html#-dataset-schema
frame = {
"is_first": np.array([f["is_first"].numpy()]),
"is_last": np.array([f["is_last"].numpy()]),
"is_terminal": np.array([f["is_terminal"].numpy()]),
"language_instruction": f["language_instruction"].numpy().decode(),
"language_instruction_2": f["language_instruction_2"].numpy().decode(),
"language_instruction_3": f["language_instruction_3"].numpy().decode(),
"observation.state.gripper_position": f["observation"]["gripper_position"].numpy(),
"observation.state.cartesian_position": f["observation"]["cartesian_position"].numpy(),
"observation.state.joint_position": f["observation"]["joint_position"].numpy(),
"observation.images.wrist_left": f["observation"]["wrist_image_left"].numpy(),
"observation.images.exterior_1_left": f["observation"]["exterior_image_1_left"].numpy(),
"observation.images.exterior_2_left": f["observation"]["exterior_image_2_left"].numpy(),
"action.gripper_position": f["action_dict"]["gripper_position"].numpy(),
"action.gripper_velocity": f["action_dict"]["gripper_velocity"].numpy(),
"action.cartesian_position": f["action_dict"]["cartesian_position"].numpy(),
"action.cartesian_velocity": f["action_dict"]["cartesian_velocity"].numpy(),
"action.joint_position": f["action_dict"]["joint_position"].numpy(),
"action.joint_velocity": f["action_dict"]["joint_velocity"].numpy(),
"discount": np.array([f["discount"].numpy()]),
"reward": np.array([f["reward"].numpy()]),
"action.original": f["action"].numpy(),
}
# language_instruction is also stored as "task" to follow LeRobot standard
frame["task"] = frame["language_instruction"]
# Add this new feature to follow LeRobot standard of using joint position + gripper
frame["observation.state"] = np.concatenate(
[frame["observation.state.joint_position"], frame["observation.state.gripper_position"]]
)
frame["action"] = np.concatenate([frame["action.joint_position"], frame["action.gripper_position"]])
# Meta data that are the same for all frames in the episode
frame.update(frame_meta)
# Cast fp64 to fp32
for key in frame:
if isinstance(frame[key], np.ndarray) and frame[key].dtype == np.float64:
frame[key] = frame[key].astype(np.float32)
yield frame
def port_droid(
raw_dir: Path,
repo_id: str,
push_to_hub: bool = False,
num_shards: int | None = None,
shard_index: int | None = None,
):
dataset_name = raw_dir.parent.name
version = raw_dir.name
data_dir = raw_dir.parent.parent
builder = tfds.builder(f"{dataset_name}/{version}", data_dir=data_dir, version="")
if num_shards is not None:
tfds_num_shards = builder.info.splits["train"].num_shards
if tfds_num_shards != DROID_SHARDS:
raise ValueError(
f"Number of shards of Droid dataset is expected to be {DROID_SHARDS} but is {tfds_num_shards}."
)
if num_shards != tfds_num_shards:
raise ValueError(
f"We only shard over the fixed number of shards provided by tensorflow dataset ({tfds_num_shards}), but {num_shards} shards provided instead."
)
if shard_index >= tfds_num_shards:
raise ValueError(
f"Shard index is greater than the num of shards ({shard_index} >= {num_shards})."
)
raw_dataset = builder.as_dataset(split=f"train[{shard_index}shard]")
else:
raw_dataset = builder.as_dataset(split="train")
lerobot_dataset = LeRobotDataset.create(
repo_id=repo_id,
robot_type=DROID_ROBOT_TYPE,
fps=DROID_FPS,
features=DROID_FEATURES,
)
start_time = time.time()
num_episodes = raw_dataset.cardinality().numpy().item()
logging.info(f"Number of episodes {num_episodes}")
for episode_index, episode in enumerate(raw_dataset):
elapsed_time = time.time() - start_time
d, h, m, s = get_elapsed_time_in_days_hours_minutes_seconds(elapsed_time)
logging.info(
f"{episode_index} / {num_episodes} episodes processed (after {d} days, {h} hours, {m} minutes, {s:.3f} seconds)"
)
for frame in generate_lerobot_frames(episode):
lerobot_dataset.add_frame(frame)
lerobot_dataset.save_episode()
logging.info("Save_episode")
if push_to_hub:
lerobot_dataset.push_to_hub(
# Add openx tag, since it belongs to the openx collection of datasets
tags=["openx"],
private=False,
)
def validate_dataset(repo_id):
"""Sanity check that ensure meta data can be loaded and all files are present."""
meta = LeRobotDatasetMetadata(repo_id)
if meta.total_episodes == 0:
raise ValueError("Number of episodes is 0.")
for ep_idx in range(meta.total_episodes):
data_path = meta.root / meta.get_data_file_path(ep_idx)
if not data_path.exists():
raise ValueError(f"Parquet file is missing in: {data_path}")
for vid_key in meta.video_keys:
vid_path = meta.root / meta.get_video_file_path(ep_idx, vid_key)
if not vid_path.exists():
raise ValueError(f"Video file is missing in: {vid_path}")
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--raw-dir",
type=Path,
required=True,
help="Directory containing input raw datasets (e.g. `path/to/dataset` or `path/to/dataset/version).",
)
parser.add_argument(
"--repo-id",
type=str,
help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset, required when push-to-hub is True",
)
parser.add_argument(
"--push-to-hub",
action="store_true",
help="Upload to hub.",
)
parser.add_argument(
"--num-shards",
type=int,
default=None,
help="Number of shards. Can be either None to load the full dataset, or 2048 to load one of the 2048 tensorflow dataset files.",
)
parser.add_argument(
"--shard-index",
type=int,
default=None,
help="Index of the shard. Can be either None to load the full dataset, or in [0,2047] to load one of the 2048 tensorflow dataset files.",
)
args = parser.parse_args()
port_droid(**vars(args))
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,148 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from pathlib import Path
from datatrove.executor import LocalPipelineExecutor
from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep
from port_datasets.droid_rlds.port_droid import DROID_SHARDS
from lerobot.datasets.aggregate import aggregate_datasets
from lerobot.utils.utils import init_logging
class AggregateDatasets(PipelineStep):
def __init__(
self,
repo_ids: list[str],
aggregated_repo_id: str,
):
super().__init__()
self.repo_ids = repo_ids
self.aggr_repo_id = aggregated_repo_id
def run(self, data=None, rank: int = 0, world_size: int = 1):
init_logging()
# Since aggregate_datasets already handles parallel processing internally,
# we only need one worker to run the entire aggregation
if rank == 0:
logging.info(f"Starting aggregation of {len(self.repo_ids)} datasets into {self.aggr_repo_id}")
aggregate_datasets(self.repo_ids, self.aggr_repo_id)
logging.info("Aggregation complete!")
else:
logging.info(f"Worker {rank} skipping - only worker 0 performs aggregation")
def make_aggregate_executor(
repo_ids, repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, slurm=True
):
kwargs = {
"pipeline": [
AggregateDatasets(repo_ids, repo_id),
],
"logging_dir": str(logs_dir / job_name),
}
if slurm:
# For aggregation, we only need 1 task since aggregate_datasets handles everything
kwargs.update(
{
"job_name": job_name,
"tasks": 1, # Only need 1 task for aggregation
"workers": 1, # Only need 1 worker
"time": "08:00:00",
"partition": partition,
"cpus_per_task": cpus_per_task,
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
}
)
executor = SlurmPipelineExecutor(**kwargs)
else:
kwargs.update(
{
"tasks": 1,
"workers": 1,
}
)
executor = LocalPipelineExecutor(**kwargs)
return executor
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--repo-id",
type=str,
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset, required when push-to-hub is True.",
)
parser.add_argument(
"--logs-dir",
type=Path,
help="Path to logs directory for `datatrove`.",
)
parser.add_argument(
"--job-name",
type=str,
default="aggr_droid",
help="Job name used in slurm, and name of the directory created inside the provided logs directory.",
)
parser.add_argument(
"--slurm",
type=int,
default=1,
help="Launch over slurm. Use `--slurm 0` to launch sequentially (useful to debug).",
)
parser.add_argument(
"--workers",
type=int,
default=1, # Changed default to 1 since aggregation doesn't need multiple workers
help="Number of slurm workers. For aggregation, this should be 1.",
)
parser.add_argument(
"--partition",
type=str,
help="Slurm partition. Ideally a CPU partition. No need for GPU partition.",
)
parser.add_argument(
"--cpus-per-task",
type=int,
default=8,
help="Number of cpus that each slurm worker will use.",
)
parser.add_argument(
"--mem-per-cpu",
type=str,
default="1950M",
help="Memory per cpu that each worker will use.",
)
args = parser.parse_args()
kwargs = vars(args)
kwargs["slurm"] = kwargs.pop("slurm") == 1
repo_ids = [f"{args.repo_id}_world_{DROID_SHARDS}_rank_{rank}" for rank in range(DROID_SHARDS)]
aggregate_executor = make_aggregate_executor(repo_ids, **kwargs)
aggregate_executor.run()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,162 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from pathlib import Path
from datatrove.executor import LocalPipelineExecutor
from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep
from port_datasets.droid_rlds.port_droid import DROID_SHARDS
class PortDroidShards(PipelineStep):
def __init__(
self,
raw_dir: Path | str,
repo_id: str = None,
):
super().__init__()
self.raw_dir = Path(raw_dir)
self.repo_id = repo_id
def run(self, data=None, rank: int = 0, world_size: int = 1):
from datasets.utils.tqdm import disable_progress_bars
from port_datasets.droid_rlds.port_droid import port_droid, validate_dataset
from lerobot.utils.utils import init_logging
init_logging()
disable_progress_bars()
shard_repo_id = f"{self.repo_id}_world_{world_size}_rank_{rank}"
try:
validate_dataset(shard_repo_id)
return
except Exception:
pass # nosec B110 - Dataset doesn't exist yet, continue with porting
port_droid(
self.raw_dir,
shard_repo_id,
push_to_hub=False,
num_shards=world_size,
shard_index=rank,
)
validate_dataset(shard_repo_id)
def make_port_executor(
raw_dir, repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, slurm=True
):
kwargs = {
"pipeline": [
PortDroidShards(raw_dir, repo_id),
],
"logging_dir": str(logs_dir / job_name),
}
if slurm:
kwargs.update(
{
"job_name": job_name,
"tasks": DROID_SHARDS,
"workers": workers,
"time": "08:00:00",
"partition": partition,
"cpus_per_task": cpus_per_task,
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
}
)
executor = SlurmPipelineExecutor(**kwargs)
else:
kwargs.update(
{
"tasks": 1,
"workers": 1,
}
)
executor = LocalPipelineExecutor(**kwargs)
return executor
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--raw-dir",
type=Path,
required=True,
help="Directory containing input raw datasets (e.g. `path/to/dataset` or `path/to/dataset/version).",
)
parser.add_argument(
"--repo-id",
type=str,
help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset, required when push-to-hub is True.",
)
parser.add_argument(
"--logs-dir",
type=Path,
help="Path to logs directory for `datatrove`.",
)
parser.add_argument(
"--job-name",
type=str,
default="port_droid",
help="Job name used in slurm, and name of the directory created inside the provided logs directory.",
)
parser.add_argument(
"--slurm",
type=int,
default=1,
help="Launch over slurm. Use `--slurm 0` to launch sequentially (useful to debug).",
)
parser.add_argument(
"--workers",
type=int,
default=2048,
help="Number of slurm workers. It should be less than the maximum number of shards.",
)
parser.add_argument(
"--partition",
type=str,
help="Slurm partition. Ideally a CPU partition. No need for GPU partition.",
)
parser.add_argument(
"--cpus-per-task",
type=int,
default=8,
help="Number of cpus that each slurm worker will use.",
)
parser.add_argument(
"--mem-per-cpu",
type=str,
default="1950M",
help="Memory per cpu that each worker will use.",
)
args = parser.parse_args()
kwargs = vars(args)
kwargs["slurm"] = kwargs.pop("slurm") == 1
port_executor = make_port_executor(**kwargs)
port_executor.run()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,281 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
from pathlib import Path
from datatrove.executor import LocalPipelineExecutor
from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep
from huggingface_hub import HfApi
from huggingface_hub.constants import REPOCARD_NAME
from port_datasets.droid_rlds.port_droid import DROID_SHARDS
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
from lerobot.datasets.utils import create_lerobot_dataset_card
from lerobot.utils.utils import init_logging
class UploadDataset(PipelineStep):
def __init__(
self,
repo_id: str,
branch: str | None = None,
revision: str | None = None,
tags: list | None = None,
license: str | None = "apache-2.0",
private: bool = False,
distant_repo_id: str | None = None,
**card_kwargs,
):
super().__init__()
self.repo_id = repo_id
self.distant_repo_id = self.repo_id if distant_repo_id is None else distant_repo_id
self.branch = branch
self.tags = tags
self.license = license
self.private = private
self.card_kwargs = card_kwargs
self.revision = revision if revision else CODEBASE_VERSION
if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", "0") != "1":
logging.warning(
'HF_HUB_ENABLE_HF_TRANSFER is not set to "1". Install hf_transfer and set the env '
"variable for faster uploads:\npip install hf-transfer\nexport HF_HUB_ENABLE_HF_TRANSFER=1"
)
self.create_repo()
def create_repo(self):
logging.info(f"Loading meta data from {self.repo_id}...")
meta = LeRobotDatasetMetadata(self.repo_id)
logging.info(f"Creating repo {self.distant_repo_id}...")
hub_api = HfApi()
hub_api.create_repo(
repo_id=self.distant_repo_id,
private=self.private,
repo_type="dataset",
exist_ok=True,
)
if self.branch:
hub_api.create_branch(
repo_id=self.distant_repo_id,
branch=self.branch,
revision=self.revision,
repo_type="dataset",
exist_ok=True,
)
if not hub_api.file_exists(
self.distant_repo_id, REPOCARD_NAME, repo_type="dataset", revision=self.branch
):
card = create_lerobot_dataset_card(
tags=self.tags, dataset_info=meta.info, license=self.license, **self.card_kwargs
)
card.push_to_hub(repo_id=self.distant_repo_id, repo_type="dataset", revision=self.branch)
hub_api.create_tag(self.distant_repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
def list_files_recursively(directory):
base_path = Path(directory)
return [str(file.relative_to(base_path)) for file in base_path.rglob("*") if file.is_file()]
logging.info(f"Listing all local files from {self.repo_id}...")
self.file_paths = list_files_recursively(meta.root)
self.file_paths = sorted(self.file_paths)
def create_chunks(self, lst, n):
from itertools import islice
it = iter(lst)
return [list(islice(it, size)) for size in [len(lst) // n + (i < len(lst) % n) for i in range(n)]]
def create_commits(self, additions):
import logging
import math
import random
import time
from huggingface_hub import create_commit
from huggingface_hub.utils import HfHubHTTPError
FILES_BETWEEN_COMMITS = 10 # noqa: N806
BASE_DELAY = 0.1 # noqa: N806
MAX_RETRIES = 12 # noqa: N806
# Split the files into smaller chunks for faster commit
# and avoiding "A commit has happened since" error
num_chunks = math.ceil(len(additions) / FILES_BETWEEN_COMMITS)
chunks = self.create_chunks(additions, num_chunks)
for chunk in chunks:
retries = 0
while True:
try:
create_commit(
self.distant_repo_id,
repo_type="dataset",
operations=chunk,
commit_message=f"DataTrove upload ({len(chunk)} files)",
revision=self.branch,
)
# TODO: every 100 chunks super_squach_commits()
logging.info("create_commit completed!")
break
except HfHubHTTPError as e:
if "A commit has happened since" in e.server_message:
if retries >= MAX_RETRIES:
logging.error(f"Failed to create commit after {MAX_RETRIES=}. Giving up.")
raise e
logging.info("Commit creation race condition issue. Waiting...")
time.sleep(BASE_DELAY * 2**retries + random.uniform(0, 2))
retries += 1
else:
raise e
def run(self, data=None, rank: int = 0, world_size: int = 1):
import logging
from datasets.utils.tqdm import disable_progress_bars
from huggingface_hub import CommitOperationAdd, preupload_lfs_files
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.utils.utils import init_logging
init_logging()
disable_progress_bars()
chunks = self.create_chunks(self.file_paths, world_size)
file_paths = chunks[rank]
if len(file_paths) == 0:
raise ValueError(file_paths)
logging.info("Pre-uploading LFS files...")
for i, path in enumerate(file_paths):
logging.info(f"{i}: {path}")
meta = LeRobotDatasetMetadata(self.repo_id)
additions = [
CommitOperationAdd(path_in_repo=path, path_or_fileobj=meta.root / path) for path in file_paths
]
preupload_lfs_files(
repo_id=self.distant_repo_id, repo_type="dataset", additions=additions, revision=self.branch
)
logging.info("Creating commits...")
self.create_commits(additions)
logging.info("Done!")
def make_upload_executor(
repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, slurm=True
):
kwargs = {
"pipeline": [
UploadDataset(repo_id),
],
"logging_dir": str(logs_dir / job_name),
}
if slurm:
kwargs.update(
{
"job_name": job_name,
"tasks": DROID_SHARDS,
"workers": workers,
"time": "08:00:00",
"partition": partition,
"cpus_per_task": cpus_per_task,
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
}
)
executor = SlurmPipelineExecutor(**kwargs)
else:
kwargs.update(
{
"tasks": DROID_SHARDS,
"workers": 1,
}
)
executor = LocalPipelineExecutor(**kwargs)
return executor
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--repo-id",
type=str,
help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset, required when push-to-hub is True.",
)
parser.add_argument(
"--logs-dir",
type=Path,
help="Path to logs directory for `datatrove`.",
)
parser.add_argument(
"--job-name",
type=str,
default="upload_droid",
help="Job name used in slurm, and name of the directory created inside the provided logs directory.",
)
parser.add_argument(
"--slurm",
type=int,
default=1,
help="Launch over slurm. Use `--slurm 0` to launch sequentially (useful to debug).",
)
parser.add_argument(
"--workers",
type=int,
default=50,
help="Number of slurm workers. It should be less than the maximum number of shards.",
)
parser.add_argument(
"--partition",
type=str,
help="Slurm partition. Ideally a CPU partition. No need for GPU partition.",
)
parser.add_argument(
"--cpus-per-task",
type=int,
default=8,
help="Number of cpus that each slurm worker will use.",
)
parser.add_argument(
"--mem-per-cpu",
type=str,
default="1950M",
help="Memory per cpu that each worker will use.",
)
init_logging()
args = parser.parse_args()
kwargs = vars(args)
kwargs["slurm"] = kwargs.pop("slurm") == 1
upload_executor = make_upload_executor(**kwargs)
upload_executor.run()
if __name__ == "__main__":
main()