import argparse
import json
import shutil
from pathlib import Path
from typing import Any, Protocol

import torch
from datasets import Dataset
from huggingface_hub import HfApi
from safetensors.torch import save_file

from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
from lerobot.common.datasets.push_dataset_to_hub.aloha_processor import (
    AlohaProcessor,
)
from lerobot.common.datasets.push_dataset_to_hub.pusht_processor import PushTProcessor
from lerobot.common.datasets.push_dataset_to_hub.umi_processor import UmiProcessor
from lerobot.common.datasets.push_dataset_to_hub.xarm_processor import XarmProcessor
from lerobot.common.datasets.utils import compute_stats, flatten_dict


def push_lerobot_dataset_to_hub(
    hf_dataset: Dataset,
    episode_data_index: dict[str, list[int]],
    info: dict[str, Any],
    stats: dict[str, dict[str, torch.Tensor]],
    root: Path,
    revision: str,
    dataset_id: str,
    community_id: str = "lerobot",
    dry_run: bool = False,
) -> None:
    """
    Pushes a dataset to the Hugging Face Hub.

    Args:
        hf_dataset (Dataset): The dataset to be pushed.
        episode_data_index (dict[str, list[int]]): The index of episode data.
        info (dict[str, Any]): Information about the dataset, e.g. fps.
        stats (dict[str, dict[str, torch.Tensor]]): Statistics of the dataset.
        root (Path): The root directory of the dataset.
        revision (str): The revision of the dataset.
        dataset_id (str): The ID of the dataset.
        community_id (str, optional): The ID of the community or the user where the
            dataset will be stored. Defaults to "lerobot".
        dry_run (bool, optional): If True, performs a dry run without actually pushing the dataset.
            Defaults to False.
    """
    if not dry_run:
        # push to main to indicate latest version
        hf_dataset.push_to_hub(f"{community_id}/{dataset_id}", token=True)

        # push to version branch
        hf_dataset.push_to_hub(f"{community_id}/{dataset_id}", token=True, revision=revision)

    # create and store meta_data
    meta_data_dir = root / community_id / dataset_id / "meta_data"
    meta_data_dir.mkdir(parents=True, exist_ok=True)

    # info
    info_path = meta_data_dir / "info.json"
    with open(str(info_path), "w") as f:
        json.dump(info, f, indent=4)

    # stats
    stats_path = meta_data_dir / "stats.safetensors"
    save_file(flatten_dict(stats), stats_path)

    # episode_data_index
    episode_data_index = {key: torch.tensor(episode_data_index[key]) for key in episode_data_index}
    ep_data_idx_path = meta_data_dir / "episode_data_index.safetensors"
    save_file(episode_data_index, ep_data_idx_path)
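
    # At this point, the local meta_data layout under `root` is:
    #   {root}/{community_id}/{dataset_id}/meta_data/info.json
    #   {root}/{community_id}/{dataset_id}/meta_data/stats.safetensors
    #   {root}/{community_id}/{dataset_id}/meta_data/episode_data_index.safetensors
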
    if not dry_run:
        api = HfApi()
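
        # Each meta_data file is uploaded twice: once to the default branch and once to the
        # `revision` branch, mirroring the two `push_to_hub` calls above.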
        api.upload_file(
            path_or_fileobj=info_path,
            path_in_repo=str(info_path).replace(f"{root}/{community_id}/{dataset_id}", ""),
            repo_id=f"{community_id}/{dataset_id}",
            repo_type="dataset",
        )
        api.upload_file(
            path_or_fileobj=info_path,
            path_in_repo=str(info_path).replace(f"{root}/{community_id}/{dataset_id}", ""),
            repo_id=f"{community_id}/{dataset_id}",
            repo_type="dataset",
            revision=revision,
        )

        # stats
        api.upload_file(
            path_or_fileobj=stats_path,
            path_in_repo=str(stats_path).replace(f"{root}/{community_id}/{dataset_id}", ""),
            repo_id=f"{community_id}/{dataset_id}",
            repo_type="dataset",
        )
        api.upload_file(
            path_or_fileobj=stats_path,
            path_in_repo=str(stats_path).replace(f"{root}/{community_id}/{dataset_id}", ""),
            repo_id=f"{community_id}/{dataset_id}",
            repo_type="dataset",
            revision=revision,
        )

        api.upload_file(
            path_or_fileobj=ep_data_idx_path,
            path_in_repo=str(ep_data_idx_path).replace(f"{root}/{community_id}/{dataset_id}", ""),
            repo_id=f"{community_id}/{dataset_id}",
            repo_type="dataset",
        )
        api.upload_file(
            path_or_fileobj=ep_data_idx_path,
            path_in_repo=str(ep_data_idx_path).replace(f"{root}/{community_id}/{dataset_id}", ""),
            repo_id=f"{community_id}/{dataset_id}",
            repo_type="dataset",
            revision=revision,
        )

    # copy the first episode and the meta_data directory into the tests folder
    num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
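    # `episode_data_index["from"][i]` / `["to"][i]` hold the start/end row of episode i
    # in the flattened dataset, so this range selects exactly the first episode.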
    hf_dataset.select(range(num_items_first_ep)).with_format("torch").save_to_disk(
        f"tests/data/{community_id}/{dataset_id}/train"
    )
    if Path(f"tests/data/{community_id}/{dataset_id}/meta_data").exists():
        shutil.rmtree(f"tests/data/{community_id}/{dataset_id}/meta_data")
    shutil.copytree(meta_data_dir, f"tests/data/{community_id}/{dataset_id}/meta_data")


def push_dataset_to_hub(
    dataset_id: str,
    root: Path,
    fps: int | None,
    dataset_folder: Path | None = None,
    dry_run: bool = False,
    revision: str = "v1.1",
    community_id: str = "lerobot",
    no_preprocess: bool = False,
    path_save_to_disk: str | None = None,
    **kwargs,
) -> None:
    """
    Download a raw dataset if needed (or use a local raw dataset), detect its raw format
    (e.g. aloha, pusht, umi), convert it to a common data format, and push it to the
    Hugging Face Hub.

    Args:
        dataset_id (str): The ID of the dataset.
        root (Path): The root directory where the dataset will be downloaded.
        fps (int | None): The desired frames per second for the dataset.
        dataset_folder (Path | None, optional): The path to the dataset folder. If not provided,
            the dataset will be downloaded using the dataset ID. Defaults to None.
        dry_run (bool, optional): If True, performs a dry run without actually pushing the dataset.
            Defaults to False.
        revision (str, optional): Version of the `push_dataset_to_hub.py` codebase used to
            preprocess the dataset. Defaults to "v1.1".
        community_id (str, optional): The ID of the community. Defaults to "lerobot".
        no_preprocess (bool, optional): If True, the dataset is not preprocessed. Defaults to False.
        path_save_to_disk (str | None, optional): The path where the dataset is saved to disk.
            Useful with `dry_run=True` to save the dataset locally without uploading it.
            By default, the dataset is not saved on disk.
        **kwargs: Additional keyword arguments for the preprocessor init method.
    """
    if dataset_folder is None:
        dataset_folder = download_raw(root=root, dataset_id=dataset_id)

    if not no_preprocess:
        processor = guess_dataset_type(dataset_folder=dataset_folder, fps=fps, **kwargs)
        data_dict, episode_data_index = processor.preprocess()
        hf_dataset = processor.to_hf_dataset(data_dict)

        info = {
            "fps": processor.fps,
        }
        stats: dict[str, dict[str, torch.Tensor]] = compute_stats(hf_dataset)

        push_lerobot_dataset_to_hub(
            hf_dataset=hf_dataset,
            episode_data_index=episode_data_index,
            info=info,
            stats=stats,
            root=root,
            revision=revision,
            dataset_id=dataset_id,
            community_id=community_id,
            dry_run=dry_run,
        )
        if path_save_to_disk:
            hf_dataset.with_format("torch").save_to_disk(dataset_path=str(path_save_to_disk))

        processor.cleanup()
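

# Example (illustrative sketch): calling `push_dataset_to_hub` above from Python instead of
# the CLI. The dataset id, root path, fps, and save path below are placeholders, not values
# required or defined by this script:
#
#   push_dataset_to_hub(
#       dataset_id="pusht",
#       root=Path("data"),
#       fps=10,
#       dry_run=True,
#       path_save_to_disk="data/pusht_processed",
#   )

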
class DatasetProcessor(Protocol):
    """A protocol for dataset processors.

    Implementations provide methods for validating, preprocessing, and converting raw
    datasets into Hugging Face datasets.

    Args:
        folder_path (str): The path to the folder containing the dataset.
        fps (int | None): The frames per second of the dataset. If None, the default value is used.
        *args: Additional positional arguments.
        **kwargs: Additional keyword arguments.
    """

    def __init__(self, folder_path: str, fps: int | None, *args, **kwargs) -> None: ...

    def is_valid(self) -> bool:
        """Check if the dataset is valid.

        Returns:
            bool: True if the dataset is valid, False otherwise.
        """
        ...

    def preprocess(self) -> tuple[dict, dict]:
        """Preprocess the dataset.

        Returns:
            tuple[dict, dict]: A tuple containing the preprocessed data dict and the episode data index.
        """
        ...

    def to_hf_dataset(self, data_dict: dict) -> Dataset:
        """Convert the preprocessed data to a Hugging Face dataset.

        Args:
            data_dict (dict): The preprocessed data.

        Returns:
            Dataset: The converted Hugging Face dataset.
        """
        ...

    @property
    def fps(self) -> int:
        """Get the frames per second of the dataset.

        Returns:
            int: The frames per second.
        """
        ...

    def cleanup(self):
        """Clean up any resources used by the dataset processor."""
        ...
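

# A minimal sketch of a class conforming to the `DatasetProcessor` protocol above.
# `MyRobotProcessor` and its internals are hypothetical and only illustrate the required surface:
#
#   class MyRobotProcessor:
#       def __init__(self, folder_path: str, fps: int | None = None, *args, **kwargs) -> None:
#           self.folder_path = folder_path
#           self._fps = fps if fps is not None else 30
#
#       def is_valid(self) -> bool:
#           return Path(self.folder_path).exists()
#
#       def preprocess(self) -> tuple[dict, dict]:
#           data_dict = {"observation.state": [], "action": []}
#           episode_data_index = {"from": [], "to": []}
#           return data_dict, episode_data_index
#
#       def to_hf_dataset(self, data_dict: dict) -> Dataset:
#           return Dataset.from_dict(data_dict)
#
#       @property
#       def fps(self) -> int:
#           return self._fps
#
#       def cleanup(self):
#           pass

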
def guess_dataset_type(dataset_folder: Path, **processor_kwargs) -> DatasetProcessor:
    if (processor := AlohaProcessor(folder_path=dataset_folder, **processor_kwargs)).is_valid():
        return processor
    if (processor := XarmProcessor(folder_path=dataset_folder, **processor_kwargs)).is_valid():
        return processor
    if (processor := PushTProcessor(folder_path=dataset_folder, **processor_kwargs)).is_valid():
        return processor
    if (processor := UmiProcessor(folder_path=dataset_folder, **processor_kwargs)).is_valid():
        return processor
    # TODO: Propose a registration mechanism for new dataset types
    raise ValueError(f"Could not guess dataset type for folder {dataset_folder}")
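

# To support a new raw format, add another check to `guess_dataset_type` above that
# instantiates the new processor and returns it when `is_valid()` succeeds, e.g. with a
# hypothetical `MyRobotProcessor` (not an existing class):
#
#   if (processor := MyRobotProcessor(folder_path=dataset_folder, **processor_kwargs)).is_valid():
#       return processor

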
def main():
    """
    Parse command line arguments and push a dataset to the Hugging Face Hub.

    The arguments select the dataset and control how it is prepared, processed, and uploaded.
    """
    parser = argparse.ArgumentParser(
        description="Push a dataset to the Hugging Face Hub with optional parameters for customization.",
        epilog="""
        Example usage:
            python -m lerobot.scripts.push_dataset_to_hub --dataset-folder /path/to/dataset --dataset-id example_dataset --root /path/to/root --dry-run --revision v2.0 --community-id example_community --fps 30 --path-save-to-disk /path/to/save --no-preprocess

        This processes and optionally pushes 'example_dataset' located in '/path/to/dataset' to the Hugging Face Hub,
        with various parameters to control the processing and uploading behavior.
        """,
    )

    parser.add_argument(
        "--dataset-folder",
        type=Path,
        default=None,
        help="The filesystem path to the raw dataset folder. If not provided, the raw dataset is downloaded using --dataset-id.",
    )
    parser.add_argument(
        "--dataset-id",
        type=str,
        required=True,
        help="Unique identifier for the dataset to be processed and uploaded.",
    )
    parser.add_argument(
        "--root", type=Path, required=True, help="Root directory where the dataset operations are managed."
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Simulate the push process without uploading any data, for testing purposes.",
    )
    parser.add_argument(
        "--community-id",
        type=str,
        default="lerobot",
        help="Community or user ID under which the dataset will be hosted on the Hub.",
    )
    parser.add_argument(
        "--fps",
        type=int,
        help="Target frame rate for video or image sequence datasets. Optional and applicable only if the dataset includes temporal media.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default="v1.0",
        help="Dataset version identifier to manage different iterations of the dataset.",
    )
    parser.add_argument(
        "--no-preprocess",
        action="store_true",
        help="Do not preprocess the dataset; set this flag if you only want to download the raw dataset.",
    )
    parser.add_argument(
        "--path-save-to-disk",
        type=Path,
        help="Optional path where the processed dataset can be saved locally.",
    )

    args = parser.parse_args()

    push_dataset_to_hub(
        dataset_folder=args.dataset_folder,
        dataset_id=args.dataset_id,
        root=args.root,
        fps=args.fps,
        dry_run=args.dry_run,
        community_id=args.community_id,
        revision=args.revision,
        no_preprocess=args.no_preprocess,
        path_save_to_disk=args.path_save_to_disk,
    )


if __name__ == "__main__":
    main()
|