Refactor push_dataset_to_hub (#118)
This commit is contained in:
@@ -1,295 +1,215 @@
|
||||
"""
|
||||
Use this script to convert your dataset into LeRobot dataset format and upload it to the Hugging Face hub,
|
||||
or store it locally. LeRobot dataset format is lightweight, fast to load from, and does not require any
|
||||
installation of neural net specific packages like pytorch, tensorflow, jax.
|
||||
|
||||
Example:
|
||||
```
|
||||
python lerobot/scripts/push_dataset_to_hub.py \
|
||||
--data-dir data \
|
||||
--dataset-id pusht \
|
||||
--raw-format pusht_zarr \
|
||||
--community-id lerobot \
|
||||
--revision v1.2 \
|
||||
--dry-run 1 \
|
||||
--save-to-disk 1 \
|
||||
--save-tests-to-disk 0 \
|
||||
--debug 1
|
||||
|
||||
python lerobot/scripts/push_dataset_to_hub.py \
|
||||
--data-dir data \
|
||||
--dataset-id xarm_lift_medium \
|
||||
--raw-format xarm_pkl \
|
||||
--community-id lerobot \
|
||||
--revision v1.2 \
|
||||
--dry-run 1 \
|
||||
--save-to-disk 1 \
|
||||
--save-tests-to-disk 0 \
|
||||
--debug 1
|
||||
|
||||
python lerobot/scripts/push_dataset_to_hub.py \
|
||||
--data-dir data \
|
||||
--dataset-id aloha_sim_insertion_scripted \
|
||||
--raw-format aloha_hdf5 \
|
||||
--community-id lerobot \
|
||||
--revision v1.2 \
|
||||
--dry-run 1 \
|
||||
--save-to-disk 1 \
|
||||
--save-tests-to-disk 0 \
|
||||
--debug 1
|
||||
|
||||
python lerobot/scripts/push_dataset_to_hub.py \
|
||||
--data-dir data \
|
||||
--dataset-id umi_cup_in_the_wild \
|
||||
--raw-format umi_zarr \
|
||||
--community-id lerobot \
|
||||
--revision v1.2 \
|
||||
--dry-run 1 \
|
||||
--save-to-disk 1 \
|
||||
--save-tests-to-disk 0 \
|
||||
--debug 1
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any, Protocol
|
||||
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from huggingface_hub import HfApi
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
|
||||
from lerobot.common.datasets.push_dataset_to_hub.aloha_processor import (
|
||||
AlohaProcessor,
|
||||
)
|
||||
from lerobot.common.datasets.push_dataset_to_hub.pusht_processor import PushTProcessor
|
||||
from lerobot.common.datasets.push_dataset_to_hub.umi_processor import UmiProcessor
|
||||
from lerobot.common.datasets.push_dataset_to_hub.xarm_processor import XarmProcessor
|
||||
from lerobot.common.datasets.utils import compute_stats, flatten_dict
|
||||
|
||||
|
||||
def push_lerobot_dataset_to_hub(
|
||||
hf_dataset: Dataset,
|
||||
episode_data_index: dict[str, list[int]],
|
||||
info: dict[str, Any],
|
||||
stats: dict[str, dict[str, torch.Tensor]],
|
||||
root: Path,
|
||||
revision: str,
|
||||
dataset_id: str,
|
||||
community_id: str = "lerobot",
|
||||
dry_run: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Pushes a dataset to the Hugging Face Hub.
|
||||
def get_from_raw_to_lerobot_format_fn(raw_format):
|
||||
if raw_format == "pusht_zarr":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.pusht_zarr_format import from_raw_to_lerobot_format
|
||||
elif raw_format == "umi_zarr":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format
|
||||
elif raw_format == "aloha_hdf5":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
|
||||
elif raw_format == "xarm_pkl":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import from_raw_to_lerobot_format
|
||||
else:
|
||||
raise ValueError(raw_format)
|
||||
|
||||
Args:
|
||||
hf_dataset (Dataset): The dataset to be pushed.
|
||||
episode_data_index (dict[str, list[int]]): The index of episode data.
|
||||
info (dict[str, Any]): Information about the dataset, eg. fps.
|
||||
stats (dict[str, dict[str, torch.Tensor]]): Statistics of the dataset.
|
||||
root (Path): The root directory of the dataset.
|
||||
revision (str): The revision of the dataset.
|
||||
dataset_id (str): The ID of the dataset.
|
||||
community_id (str, optional): The ID of the community or the user where the
|
||||
dataset will be stored. Defaults to "lerobot".
|
||||
dry_run (bool, optional): If True, performs a dry run without actually pushing the dataset. Defaults to False.
|
||||
"""
|
||||
if not dry_run:
|
||||
# push to main to indicate latest version
|
||||
hf_dataset.push_to_hub(f"{community_id}/{dataset_id}", token=True)
|
||||
return from_raw_to_lerobot_format
|
||||
|
||||
# push to version branch
|
||||
hf_dataset.push_to_hub(f"{community_id}/{dataset_id}", token=True, revision=revision)
|
||||
|
||||
# create and store meta_data
|
||||
meta_data_dir = root / community_id / dataset_id / "meta_data"
|
||||
def save_meta_data(info, stats, episode_data_index, meta_data_dir):
|
||||
meta_data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# info
|
||||
# save info
|
||||
info_path = meta_data_dir / "info.json"
|
||||
|
||||
with open(str(info_path), "w") as f:
|
||||
json.dump(info, f, indent=4)
|
||||
# stats
|
||||
|
||||
# save stats
|
||||
stats_path = meta_data_dir / "stats.safetensors"
|
||||
save_file(flatten_dict(stats), stats_path)
|
||||
|
||||
# episode_data_index
|
||||
# save episode_data_index
|
||||
episode_data_index = {key: torch.tensor(episode_data_index[key]) for key in episode_data_index}
|
||||
ep_data_idx_path = meta_data_dir / "episode_data_index.safetensors"
|
||||
save_file(episode_data_index, ep_data_idx_path)
|
||||
|
||||
if not dry_run:
|
||||
api = HfApi()
|
||||
|
||||
def push_meta_data_to_hub(meta_data_dir, repo_id, revision):
|
||||
api = HfApi()
|
||||
|
||||
def upload(filename, revision):
|
||||
api.upload_file(
|
||||
path_or_fileobj=info_path,
|
||||
path_in_repo=str(info_path).replace(f"{root}/{community_id}/{dataset_id}", ""),
|
||||
repo_id=f"{community_id}/{dataset_id}",
|
||||
repo_type="dataset",
|
||||
)
|
||||
api.upload_file(
|
||||
path_or_fileobj=info_path,
|
||||
path_in_repo=str(info_path).replace(f"{root}/{community_id}/{dataset_id}", ""),
|
||||
repo_id=f"{community_id}/{dataset_id}",
|
||||
repo_type="dataset",
|
||||
path_or_fileobj=meta_data_dir / filename,
|
||||
path_in_repo=f"meta_data/{filename}",
|
||||
repo_id=repo_id,
|
||||
revision=revision,
|
||||
repo_type="dataset",
|
||||
)
|
||||
|
||||
# stats
|
||||
api.upload_file(
|
||||
path_or_fileobj=stats_path,
|
||||
path_in_repo=str(stats_path).replace(f"{root}/{community_id}/{dataset_id}", ""),
|
||||
repo_id=f"{community_id}/{dataset_id}",
|
||||
repo_type="dataset",
|
||||
)
|
||||
api.upload_file(
|
||||
path_or_fileobj=stats_path,
|
||||
path_in_repo=str(stats_path).replace(f"{root}/{community_id}/{dataset_id}", ""),
|
||||
repo_id=f"{community_id}/{dataset_id}",
|
||||
repo_type="dataset",
|
||||
revision=revision,
|
||||
)
|
||||
|
||||
api.upload_file(
|
||||
path_or_fileobj=ep_data_idx_path,
|
||||
path_in_repo=str(ep_data_idx_path).replace(f"{root}/{community_id}/{dataset_id}", ""),
|
||||
repo_id=f"{community_id}/{dataset_id}",
|
||||
repo_type="dataset",
|
||||
)
|
||||
api.upload_file(
|
||||
path_or_fileobj=ep_data_idx_path,
|
||||
path_in_repo=str(ep_data_idx_path).replace(f"{root}/{community_id}/{dataset_id}", ""),
|
||||
repo_id=f"{community_id}/{dataset_id}",
|
||||
repo_type="dataset",
|
||||
revision=revision,
|
||||
)
|
||||
|
||||
# copy in tests folder, the first episode and the meta_data directory
|
||||
num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
|
||||
hf_dataset.select(range(num_items_first_ep)).with_format("torch").save_to_disk(
|
||||
f"tests/data/{community_id}/{dataset_id}/train"
|
||||
)
|
||||
if Path(f"tests/data/{community_id}/{dataset_id}/meta_data").exists():
|
||||
shutil.rmtree(f"tests/data/{community_id}/{dataset_id}/meta_data")
|
||||
shutil.copytree(meta_data_dir, f"tests/data/{community_id}/{dataset_id}/meta_data")
|
||||
upload("info.json", "main")
|
||||
upload("info.json", revision)
|
||||
upload("stats.safetensors", "main")
|
||||
upload("stats.safetensors", revision)
|
||||
upload("episode_data_index.safetensors", "main")
|
||||
upload("episode_data_index.safetensors", revision)
|
||||
|
||||
|
||||
def push_dataset_to_hub(
|
||||
data_dir: Path,
|
||||
dataset_id: str,
|
||||
root: Path,
|
||||
raw_format: str | None,
|
||||
community_id: str,
|
||||
revision: str,
|
||||
dry_run: bool,
|
||||
save_to_disk: bool,
|
||||
tests_data_dir: Path,
|
||||
save_tests_to_disk: bool,
|
||||
fps: int | None,
|
||||
dataset_folder: Path | None = None,
|
||||
dry_run: bool = False,
|
||||
revision: str = "v1.1",
|
||||
community_id: str = "lerobot",
|
||||
no_preprocess: bool = False,
|
||||
path_save_to_disk: str | None = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Download a raw dataset if needed or access a local raw dataset, detect the raw format (e.g. aloha, pusht, umi) and process it accordingly in a common data format which is then pushed to the Hugging Face Hub.
|
||||
video: bool,
|
||||
debug: bool,
|
||||
):
|
||||
raw_dir = data_dir / f"{dataset_id}_raw"
|
||||
|
||||
Args:
|
||||
dataset_id (str): The ID of the dataset.
|
||||
root (Path): The root directory where the dataset will be downloaded.
|
||||
fps (int | None): The desired frames per second for the dataset.
|
||||
dataset_folder (Path | None, optional): The path to the dataset folder. If not provided, the dataset will be downloaded using the dataset ID. Defaults to None.
|
||||
dry_run (bool, optional): If True, performs a dry run without actually pushing the dataset. Defaults to False.
|
||||
revision (str, optional): Version of the `push_dataset_to_hub.py` codebase used to preprocess the dataset. Defaults to "v1.1".
|
||||
community_id (str, optional): The ID of the community. Defaults to "lerobot".
|
||||
no_preprocess (bool, optional): If True, does not preprocesses the dataset. Defaults to False.
|
||||
path_save_to_disk (str | None, optional): The path to save the dataset to disk. Works when `dry_run` is True, which allows to only save on disk without uploading. By default, the dataset is not saved on disk.
|
||||
**kwargs: Additional keyword arguments for the preprocessor init method.
|
||||
out_dir = data_dir / community_id / dataset_id
|
||||
meta_data_dir = out_dir / "meta_data"
|
||||
videos_dir = out_dir / "videos"
|
||||
|
||||
tests_out_dir = tests_data_dir / community_id / dataset_id
|
||||
tests_meta_data_dir = tests_out_dir / "meta_data"
|
||||
|
||||
"""
|
||||
if dataset_folder is None:
|
||||
dataset_folder = download_raw(root=root, dataset_id=dataset_id)
|
||||
if out_dir.exists():
|
||||
shutil.rmtree(out_dir)
|
||||
|
||||
if not no_preprocess:
|
||||
processor = guess_dataset_type(dataset_folder=dataset_folder, fps=fps, **kwargs)
|
||||
data_dict, episode_data_index = processor.preprocess()
|
||||
hf_dataset = processor.to_hf_dataset(data_dict)
|
||||
if tests_out_dir.exists():
|
||||
shutil.rmtree(tests_out_dir)
|
||||
|
||||
info = {
|
||||
"fps": processor.fps,
|
||||
}
|
||||
stats: dict[str, dict[str, torch.Tensor]] = compute_stats(hf_dataset)
|
||||
if not raw_dir.exists():
|
||||
download_raw(raw_dir, dataset_id)
|
||||
|
||||
push_lerobot_dataset_to_hub(
|
||||
hf_dataset=hf_dataset,
|
||||
episode_data_index=episode_data_index,
|
||||
info=info,
|
||||
stats=stats,
|
||||
root=root,
|
||||
revision=revision,
|
||||
dataset_id=dataset_id,
|
||||
community_id=community_id,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
if path_save_to_disk:
|
||||
hf_dataset.with_format("torch").save_to_disk(dataset_path=str(path_save_to_disk))
|
||||
if raw_format is None:
|
||||
# TODO(rcadene, adilzouitine): implement auto_find_raw_format
|
||||
raise NotImplementedError()
|
||||
# raw_format = auto_find_raw_format(raw_dir)
|
||||
|
||||
processor.cleanup()
|
||||
from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
|
||||
|
||||
# convert dataset from original raw format to LeRobot format
|
||||
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(raw_dir, out_dir, fps, video, debug)
|
||||
|
||||
class DatasetProcessor(Protocol):
|
||||
"""A class for processing datasets.
|
||||
stats = compute_stats(hf_dataset)
|
||||
|
||||
This class provides methods for validating, preprocessing, and converting datasets.
|
||||
if save_to_disk:
|
||||
hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
|
||||
hf_dataset.save_to_disk(str(out_dir / "train"))
|
||||
|
||||
Args:
|
||||
folder_path (str): The path to the folder containing the dataset.
|
||||
fps (int | None): The frames per second of the dataset. If None, the default value is used.
|
||||
*args: Additional positional arguments.
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
if not dry_run or save_to_disk:
|
||||
# mandatory for upload
|
||||
save_meta_data(info, stats, episode_data_index, meta_data_dir)
|
||||
|
||||
def __init__(self, folder_path: str, fps: int | None, *args, **kwargs) -> None: ...
|
||||
if not dry_run:
|
||||
repo_id = f"{community_id}/{dataset_id}"
|
||||
hf_dataset.push_to_hub(repo_id, token=True, revision="main")
|
||||
hf_dataset.push_to_hub(repo_id, token=True, revision=revision)
|
||||
push_meta_data_to_hub(repo_id, meta_data_dir)
|
||||
if video:
|
||||
push_meta_data_to_hub(repo_id, videos_dir)
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
"""Check if the dataset is valid.
|
||||
if save_tests_to_disk:
|
||||
# get the first episode
|
||||
num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
|
||||
test_hf_dataset = hf_dataset.select(range(num_items_first_ep))
|
||||
|
||||
Returns:
|
||||
bool: True if the dataset is valid, False otherwise.
|
||||
"""
|
||||
...
|
||||
test_hf_dataset = test_hf_dataset.with_format(None)
|
||||
test_hf_dataset.save_to_disk(str(tests_out_dir / "train"))
|
||||
|
||||
def preprocess(self) -> tuple[dict, dict]:
|
||||
"""Preprocess the dataset.
|
||||
|
||||
Returns:
|
||||
tuple[dict, dict]: A tuple containing two dictionaries representing the preprocessed data.
|
||||
"""
|
||||
...
|
||||
|
||||
def to_hf_dataset(self, data_dict: dict) -> Dataset:
|
||||
"""Convert the preprocessed data to a Hugging Face dataset.
|
||||
|
||||
Args:
|
||||
data_dict (dict): The preprocessed data.
|
||||
|
||||
Returns:
|
||||
Dataset: The converted Hugging Face dataset.
|
||||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
def fps(self) -> int:
|
||||
"""Get the frames per second of the dataset.
|
||||
|
||||
Returns:
|
||||
int: The frames per second.
|
||||
"""
|
||||
...
|
||||
|
||||
def cleanup(self):
|
||||
"""Clean up any resources used by the dataset processor."""
|
||||
...
|
||||
|
||||
|
||||
def guess_dataset_type(dataset_folder: Path, **processor_kwargs) -> DatasetProcessor:
|
||||
if (processor := AlohaProcessor(folder_path=dataset_folder, **processor_kwargs)).is_valid():
|
||||
return processor
|
||||
if (processor := XarmProcessor(folder_path=dataset_folder, **processor_kwargs)).is_valid():
|
||||
return processor
|
||||
if (processor := PushTProcessor(folder_path=dataset_folder, **processor_kwargs)).is_valid():
|
||||
return processor
|
||||
if (processor := UmiProcessor(folder_path=dataset_folder, **processor_kwargs)).is_valid():
|
||||
return processor
|
||||
# TODO: Propose a registration mechanism for new dataset types
|
||||
raise ValueError(f"Could not guess dataset type for folder {dataset_folder}")
|
||||
# copy meta data to tests directory
|
||||
if Path(tests_meta_data_dir).exists():
|
||||
shutil.rmtree(tests_meta_data_dir)
|
||||
shutil.copytree(meta_data_dir, tests_meta_data_dir)
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Main function to process command line arguments and push dataset to Hugging Face Hub.
|
||||
|
||||
Parses command line arguments to get dataset details and conditions under which the dataset
|
||||
is processed and pushed. It manages dataset preparation and uploading based on the user-defined parameters.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Push a dataset to the Hugging Face Hub with optional parameters for customization.",
|
||||
epilog="""
|
||||
Example usage:
|
||||
python -m lerobot.scripts.push_dataset_to_hub --dataset-folder /path/to/dataset --dataset-id example_dataset --root /path/to/root --dry-run --revision v2.0 --community-id example_community --fps 30 --path-save-to-disk /path/to/save --no-preprocess
|
||||
|
||||
This processes and optionally pushes 'example_dataset' located in '/path/to/dataset' to Hugging Face Hub,
|
||||
with various parameters to control the processing and uploading behavior.
|
||||
""",
|
||||
)
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--dataset-folder",
|
||||
"--data-dir",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="The filesystem path to the dataset folder. If not provided, the dataset must be identified and managed by other means.",
|
||||
required=True,
|
||||
help="Root directory containing datasets (e.g. `data` or `tmp/data` or `/tmp/lerobot/data`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Unique identifier for the dataset to be processed and uploaded.",
|
||||
help="Name of the dataset (e.g. `pusht`, `aloha_sim_insertion_human`), which matches the folder where the data is stored (e.g. `data/pusht`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root", type=Path, required=True, help="Root directory where the dataset operations are managed."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="Simulate the push process without uploading any data, for testing purposes.",
|
||||
"--raw-format",
|
||||
type=str,
|
||||
help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`). If not provided, will be detected automatically.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--community-id",
|
||||
@@ -297,41 +217,57 @@ def main():
|
||||
default="lerobot",
|
||||
help="Community or user ID under which the dataset will be hosted on the Hub.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fps",
|
||||
type=int,
|
||||
help="Target frame rate for video or image sequence datasets. Optional and applicable only if the dataset includes temporal media.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--revision",
|
||||
type=str,
|
||||
default="v1.0",
|
||||
help="Dataset version identifier to manage different iterations of the dataset.",
|
||||
default="v1.2",
|
||||
help="Codebase version used to generate the dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-preprocess",
|
||||
action="store_true",
|
||||
help="Does not preprocess the dataset, set this flag if you only want dowload the dataset raw.",
|
||||
"--dry-run",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Run everything without uploading to hub, for testing purposes or storing a dataset locally.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--path-save-to-disk",
|
||||
"--save-to-disk",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Save the dataset in the directory specified by `--data-dir`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tests-data-dir",
|
||||
type=Path,
|
||||
help="Optional path where the processed dataset can be saved locally.",
|
||||
default="tests/data",
|
||||
help="Directory containing tests artifacts datasets.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-tests-to-disk",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Save the dataset with 1 episode used for unit tests in the directory specified by `--tests-data-dir`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fps",
|
||||
type=int,
|
||||
help="Frame rate used to collect videos. If not provided, use the default one specified in the code.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--video",
|
||||
type=int,
|
||||
# TODO(rcadene): enable when video PR merges
|
||||
default=0,
|
||||
help="Convert each episode of the raw dataset to an mp4 video. This option allows 60 times lower disk space consumption and 25 faster loading time during training.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Debug mode process the first episode only.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
push_dataset_to_hub(
|
||||
dataset_folder=args.dataset_folder,
|
||||
dataset_id=args.dataset_id,
|
||||
root=args.root,
|
||||
fps=args.fps,
|
||||
dry_run=args.dry_run,
|
||||
community_id=args.community_id,
|
||||
revision=args.revision,
|
||||
no_preprocess=args.no_preprocess,
|
||||
path_save_to_disk=args.path_save_to_disk,
|
||||
)
|
||||
push_dataset_to_hub(**vars(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user