Refactor push_dataset_to_hub (#118)

This commit is contained in:
Remi
2024-04-30 14:25:41 +02:00
committed by GitHub
parent 2765877f28
commit e4e739f4f8
25 changed files with 1089 additions and 1192 deletions

lerobot/scripts/push_dataset_to_hub.py

@@ -1,295 +1,215 @@
"""
Use this script to convert your dataset into LeRobot dataset format and upload it to the Hugging Face hub,
or store it locally. The LeRobot dataset format is lightweight, fast to load from, and does not require
installing any neural-net-specific packages such as pytorch, tensorflow, or jax.

Examples:
```
python lerobot/scripts/push_dataset_to_hub.py \
--data-dir data \
--dataset-id pusht \
--raw-format pusht_zarr \
--community-id lerobot \
--revision v1.2 \
--dry-run 1 \
--save-to-disk 1 \
--save-tests-to-disk 0 \
--debug 1

python lerobot/scripts/push_dataset_to_hub.py \
--data-dir data \
--dataset-id xarm_lift_medium \
--raw-format xarm_pkl \
--community-id lerobot \
--revision v1.2 \
--dry-run 1 \
--save-to-disk 1 \
--save-tests-to-disk 0 \
--debug 1

python lerobot/scripts/push_dataset_to_hub.py \
--data-dir data \
--dataset-id aloha_sim_insertion_scripted \
--raw-format aloha_hdf5 \
--community-id lerobot \
--revision v1.2 \
--dry-run 1 \
--save-to-disk 1 \
--save-tests-to-disk 0 \
--debug 1

python lerobot/scripts/push_dataset_to_hub.py \
--data-dir data \
--dataset-id umi_cup_in_the_wild \
--raw-format umi_zarr \
--community-id lerobot \
--revision v1.2 \
--dry-run 1 \
--save-to-disk 1 \
--save-tests-to-disk 0 \
--debug 1
```
"""

import argparse
import json
import shutil
from pathlib import Path

import torch
from huggingface_hub import HfApi
from safetensors.torch import save_file

from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
from lerobot.common.datasets.utils import compute_stats, flatten_dict


def get_from_raw_to_lerobot_format_fn(raw_format):
    if raw_format == "pusht_zarr":
        from lerobot.common.datasets.push_dataset_to_hub.pusht_zarr_format import from_raw_to_lerobot_format
    elif raw_format == "umi_zarr":
        from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format
    elif raw_format == "aloha_hdf5":
        from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
    elif raw_format == "xarm_pkl":
        from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import from_raw_to_lerobot_format
    else:
        raise ValueError(raw_format)
    return from_raw_to_lerobot_format
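
# A minimal usage sketch of the converter returned above (illustrative only; the
# "pusht_zarr" format and the `data/...` paths are placeholder values, and the call
# mirrors how `push_dataset_to_hub` invokes the converter below):
#
#     from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn("pusht_zarr")
#     hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
#         Path("data/pusht_raw"), Path("data/lerobot/pusht"), None, False, False
#     )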


def save_meta_data(info, stats, episode_data_index, meta_data_dir):
    meta_data_dir.mkdir(parents=True, exist_ok=True)

    # save info
    info_path = meta_data_dir / "info.json"
    with open(str(info_path), "w") as f:
        json.dump(info, f, indent=4)

    # save stats
    stats_path = meta_data_dir / "stats.safetensors"
    save_file(flatten_dict(stats), stats_path)

    # save episode_data_index
    episode_data_index = {key: torch.tensor(episode_data_index[key]) for key in episode_data_index}
    ep_data_idx_path = meta_data_dir / "episode_data_index.safetensors"
    save_file(episode_data_index, ep_data_idx_path)
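
# A sketch of how the saved metadata could be read back for inspection, assuming the
# `load_file`/`json` counterparts of the calls above (illustrative only):
#
#     from safetensors.torch import load_file
#
#     info = json.load(open(meta_data_dir / "info.json"))
#     stats_flat = load_file(meta_data_dir / "stats.safetensors")
#     episode_data_index = load_file(meta_data_dir / "episode_data_index.safetensors")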


def push_meta_data_to_hub(repo_id, meta_data_dir, revision):
    api = HfApi()

    def upload(filename, revision):
        api.upload_file(
            path_or_fileobj=meta_data_dir / filename,
            path_in_repo=f"meta_data/{filename}",
            repo_id=repo_id,
            revision=revision,
            repo_type="dataset",
        )

    upload("info.json", "main")
    upload("info.json", revision)
    upload("stats.safetensors", "main")
    upload("stats.safetensors", revision)
    upload("episode_data_index.safetensors", "main")
    upload("episode_data_index.safetensors", revision)


def push_dataset_to_hub(
    data_dir: Path,
    dataset_id: str,
    raw_format: str | None,
    community_id: str,
    revision: str,
    dry_run: bool,
    save_to_disk: bool,
    tests_data_dir: Path,
    save_tests_to_disk: bool,
    fps: int | None,
    video: bool,
    debug: bool,
):
    raw_dir = data_dir / f"{dataset_id}_raw"

    out_dir = data_dir / community_id / dataset_id
    meta_data_dir = out_dir / "meta_data"
    videos_dir = out_dir / "videos"

    tests_out_dir = tests_data_dir / community_id / dataset_id
    tests_meta_data_dir = tests_out_dir / "meta_data"

    if out_dir.exists():
        shutil.rmtree(out_dir)

    if tests_out_dir.exists():
        shutil.rmtree(tests_out_dir)

    if not raw_dir.exists():
        download_raw(raw_dir, dataset_id)

    if raw_format is None:
        # TODO(rcadene, adilzouitine): implement auto_find_raw_format
        raise NotImplementedError()
        # raw_format = auto_find_raw_format(raw_dir)

    from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)

    # convert dataset from original raw format to LeRobot format
    hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(raw_dir, out_dir, fps, video, debug)

    stats = compute_stats(hf_dataset)

    if save_to_disk:
        hf_dataset = hf_dataset.with_format(None)  # remove transforms that can't be saved
        hf_dataset.save_to_disk(str(out_dir / "train"))

    if not dry_run or save_to_disk:
        # mandatory for upload
        save_meta_data(info, stats, episode_data_index, meta_data_dir)

    if not dry_run:
        repo_id = f"{community_id}/{dataset_id}"
        hf_dataset.push_to_hub(repo_id, token=True, revision="main")
        hf_dataset.push_to_hub(repo_id, token=True, revision=revision)

        push_meta_data_to_hub(repo_id, meta_data_dir, revision)

        if video:
            push_meta_data_to_hub(repo_id, videos_dir, revision)

    if save_tests_to_disk:
        # get the first episode
        num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
        test_hf_dataset = hf_dataset.select(range(num_items_first_ep))
        test_hf_dataset = test_hf_dataset.with_format(None)
        test_hf_dataset.save_to_disk(str(tests_out_dir / "train"))

        # copy meta data to tests directory
        if Path(tests_meta_data_dir).exists():
            shutil.rmtree(tests_meta_data_dir)
        shutil.copytree(meta_data_dir, tests_meta_data_dir)
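
# Calling the function directly from Python is equivalent to the first CLI example in
# the module docstring (a sketch; the ids and paths are the same placeholder values,
# with booleans in place of the CLI's 0/1 integers):
#
#     push_dataset_to_hub(
#         data_dir=Path("data"),
#         dataset_id="pusht",
#         raw_format="pusht_zarr",
#         community_id="lerobot",
#         revision="v1.2",
#         dry_run=True,
#         save_to_disk=True,
#         tests_data_dir=Path("tests/data"),
#         save_tests_to_disk=False,
#         fps=None,
#         video=False,
#         debug=True,
#     )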


def main():
    """Parse command-line arguments and push the dataset to the Hugging Face Hub."""
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--data-dir",
        type=Path,
        required=True,
        help="Root directory containing datasets (e.g. `data` or `tmp/data` or `/tmp/lerobot/data`).",
    )
    parser.add_argument(
        "--dataset-id",
        type=str,
        required=True,
        help="Name of the dataset (e.g. `pusht`, `aloha_sim_insertion_human`), which matches the folder where the data is stored (e.g. `data/pusht`).",
    )
    parser.add_argument(
        "--raw-format",
        type=str,
        help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`). If not provided, it will be detected automatically (not yet implemented).",
    )
    parser.add_argument(
        "--community-id",
        type=str,
        default="lerobot",
        help="Community or user ID under which the dataset will be hosted on the Hub.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default="v1.2",
        help="Codebase version used to generate the dataset.",
    )
    parser.add_argument(
        "--dry-run",
        type=int,
        default=0,
        help="Run everything without uploading to the hub, for testing purposes or for storing a dataset locally.",
    )
    parser.add_argument(
        "--save-to-disk",
        type=int,
        default=1,
        help="Save the dataset in the directory specified by `--data-dir`.",
    )
    parser.add_argument(
        "--tests-data-dir",
        type=Path,
        default="tests/data",
        help="Directory where the test artifact datasets are stored.",
    )
    parser.add_argument(
        "--save-tests-to-disk",
        type=int,
        default=1,
        help="Save a 1-episode copy of the dataset, used for unit tests, in the directory specified by `--tests-data-dir`.",
    )
    parser.add_argument(
        "--fps",
        type=int,
        help="Frame rate used to collect videos. If not provided, the default value specified in the code is used.",
    )
    parser.add_argument(
        "--video",
        type=int,
        # TODO(rcadene): enable when video PR merges
        default=0,
        help="Convert each episode of the raw dataset to an mp4 video. This option allows 60 times lower disk space consumption and 25 times faster loading time during training.",
    )
    parser.add_argument(
        "--debug",
        type=int,
        default=0,
        help="Debug mode: process only the first episode.",
    )

    args = parser.parse_args()
    push_dataset_to_hub(**vars(args))

if __name__ == "__main__":