Improve push_dataset_to_hub API + Add unit tests (#231)

Co-authored-by: Remi <re.cadene@gmail.com>
Co-authored-by: Simon Alibert <alibert.sim@gmail.com>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
This commit is contained in:
Thomas Wolf
2024-06-13 15:18:02 +02:00
committed by GitHub
parent c38f535c9f
commit 125bd93e29
11 changed files with 750 additions and 419 deletions

View File

@@ -18,58 +18,39 @@ Use this script to convert your dataset into LeRobot dataset format and upload i
or store it locally. LeRobot dataset format is lightweight, fast to load from, and does not require any
installation of neural net specific packages like pytorch, tensorflow, jax.
Example:
Example of how to download raw datasets, convert them into LeRobotDataset format, and push them to the hub:
```
python lerobot/scripts/push_dataset_to_hub.py \
--data-dir data \
--dataset-id pusht \
--raw-dir data/pusht_raw \
--raw-format pusht_zarr \
--community-id lerobot \
--dry-run 1 \
--save-to-disk 1 \
--save-tests-to-disk 0 \
--debug 1
--repo-id lerobot/pusht
python lerobot/scripts/push_dataset_to_hub.py \
--data-dir data \
--dataset-id xarm_lift_medium \
--raw-dir data/xarm_lift_medium_raw \
--raw-format xarm_pkl \
--community-id lerobot \
--dry-run 1 \
--save-to-disk 1 \
--save-tests-to-disk 0 \
--debug 1
--repo-id lerobot/xarm_lift_medium
python lerobot/scripts/push_dataset_to_hub.py \
--data-dir data \
--dataset-id aloha_sim_insertion_scripted \
--raw-dir data/aloha_sim_insertion_scripted_raw \
--raw-format aloha_hdf5 \
--community-id lerobot \
--dry-run 1 \
--save-to-disk 1 \
--save-tests-to-disk 0 \
--debug 1
--repo-id lerobot/aloha_sim_insertion_scripted
python lerobot/scripts/push_dataset_to_hub.py \
--data-dir data \
--dataset-id umi_cup_in_the_wild \
--raw-dir data/umi_cup_in_the_wild_raw \
--raw-format umi_zarr \
--community-id lerobot \
--dry-run 1 \
--save-to-disk 1 \
--save-tests-to-disk 0 \
--debug 1
--repo-id lerobot/umi_cup_in_the_wild
```
"""
import argparse
import json
import shutil
import warnings
from pathlib import Path
from typing import Any
import torch
from huggingface_hub import HfApi
from huggingface_hub import HfApi, create_branch
from safetensors.torch import save_file
from lerobot.common.datasets.compute_stats import compute_stats
@@ -85,8 +66,8 @@ def get_from_raw_to_lerobot_format_fn(raw_format: str):
from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format
elif raw_format == "aloha_hdf5":
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
elif raw_format == "aloha_dora":
from lerobot.common.datasets.push_dataset_to_hub.aloha_dora_format import from_raw_to_lerobot_format
elif raw_format == "dora_parquet":
from lerobot.common.datasets.push_dataset_to_hub.dora_parquet_format import from_raw_to_lerobot_format
elif raw_format == "xarm_pkl":
from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import from_raw_to_lerobot_format
else:
@@ -147,39 +128,61 @@ def push_videos_to_hub(repo_id: str, videos_dir: str | Path, revision: str | Non
def push_dataset_to_hub(
data_dir: Path,
dataset_id: str,
raw_format: str | None,
community_id: str,
revision: str,
dry_run: bool,
save_to_disk: bool,
tests_data_dir: Path,
save_tests_to_disk: bool,
fps: int | None,
video: bool,
batch_size: int,
num_workers: int,
debug: bool,
raw_dir: Path,
raw_format: str,
repo_id: str,
push_to_hub: bool = True,
local_dir: Path | None = None,
fps: int | None = None,
video: bool = True,
batch_size: int = 32,
num_workers: int = 8,
episodes: list[int] | None = None,
force_override: bool = False,
cache_dir: Path = Path("/tmp"),
tests_data_dir: Path | None = None,
):
repo_id = f"{community_id}/{dataset_id}"
# Check repo_id is well formated
if len(repo_id.split("/")) != 2:
raise ValueError(
f"`repo_id` is expected to contain a community or user id `/` the name of the dataset (e.g. 'lerobot/pusht'), but instead contains '{repo_id}'."
)
user_id, dataset_id = repo_id.split("/")
raw_dir = data_dir / f"{dataset_id}_raw"
# Robustify when `raw_dir` is str instead of Path
raw_dir = Path(raw_dir)
if not raw_dir.exists():
raise NotADirectoryError(
f"{raw_dir} does not exists. Check your paths or run this command to download an existing raw dataset on the hub:"
f"python lerobot/common/datasets/push_dataset_to_hub/_download_raw.py --raw-dir your/raw/dir --repo-id your/repo/id_raw"
)
out_dir = data_dir / repo_id
meta_data_dir = out_dir / "meta_data"
videos_dir = out_dir / "videos"
if local_dir:
# Robustify when `local_dir` is str instead of Path
local_dir = Path(local_dir)
tests_out_dir = tests_data_dir / repo_id
tests_meta_data_dir = tests_out_dir / "meta_data"
tests_videos_dir = tests_out_dir / "videos"
# Send warning if local_dir isn't well formated
if local_dir.parts[-2] != user_id or local_dir.parts[-1] != dataset_id:
warnings.warn(
f"`local_dir` ({local_dir}) doesn't contain a community or user id `/` the name of the dataset that match the `repo_id` (e.g. 'data/lerobot/pusht'). Following this naming convention is advised, but not mandatory.",
stacklevel=1,
)
if out_dir.exists():
shutil.rmtree(out_dir)
# Check we don't override an existing `local_dir` by mistake
if local_dir.exists():
if force_override:
shutil.rmtree(local_dir)
else:
raise ValueError(f"`local_dir` already exists ({local_dir}). Use `--force-override 1`.")
if tests_out_dir.exists() and save_tests_to_disk:
shutil.rmtree(tests_out_dir)
meta_data_dir = local_dir / "meta_data"
videos_dir = local_dir / "videos"
else:
# Temporary directory used to store images, videos, meta_data
meta_data_dir = Path(cache_dir) / "meta_data"
videos_dir = Path(cache_dir) / "videos"
# Download the raw dataset if available
if not raw_dir.exists():
download_raw(raw_dir, dataset_id)
@@ -188,14 +191,14 @@ def push_dataset_to_hub(
raise NotImplementedError()
# raw_format = auto_find_raw_format(raw_dir)
from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
# convert dataset from original raw format to LeRobot format
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(raw_dir, out_dir, fps, video, debug)
from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
raw_dir, videos_dir, fps, video, episodes
)
lerobot_dataset = LeRobotDataset.from_preloaded(
repo_id=repo_id,
version=revision,
hf_dataset=hf_dataset,
episode_data_index=episode_data_index,
info=info,
@@ -203,103 +206,80 @@ def push_dataset_to_hub(
)
stats = compute_stats(lerobot_dataset, batch_size, num_workers)
if save_to_disk:
if local_dir:
hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
hf_dataset.save_to_disk(str(out_dir / "train"))
hf_dataset.save_to_disk(str(local_dir / "train"))
if not dry_run or save_to_disk:
if push_to_hub or local_dir:
# mandatory for upload
save_meta_data(info, stats, episode_data_index, meta_data_dir)
if not dry_run:
# TODO(rcadene): token needs to be a str | None
hf_dataset.push_to_hub(repo_id, token=True, revision="main")
hf_dataset.push_to_hub(repo_id, token=True, revision=revision)
if push_to_hub:
hf_dataset.push_to_hub(repo_id, revision="main")
push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
push_meta_data_to_hub(repo_id, meta_data_dir, revision=revision)
if video:
push_videos_to_hub(repo_id, videos_dir, revision="main")
push_videos_to_hub(repo_id, videos_dir, revision=revision)
create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
if save_tests_to_disk:
if tests_data_dir:
# get the first episode
num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
test_hf_dataset = hf_dataset.select(range(num_items_first_ep))
test_hf_dataset = test_hf_dataset.with_format(None)
test_hf_dataset.save_to_disk(str(tests_out_dir / "train"))
test_hf_dataset.save_to_disk(str(tests_data_dir / repo_id / "train"))
save_meta_data(info, stats, episode_data_index, tests_meta_data_dir)
tests_meta_data = tests_data_dir / repo_id / "meta_data"
save_meta_data(info, stats, episode_data_index, tests_meta_data)
# copy videos of first episode to tests directory
episode_index = 0
tests_videos_dir = tests_data_dir / repo_id / "videos"
tests_videos_dir.mkdir(parents=True, exist_ok=True)
for key in lerobot_dataset.video_frame_keys:
fname = f"{key}_episode_{episode_index:06d}.mp4"
shutil.copy(videos_dir / fname, tests_videos_dir / fname)
if not save_to_disk and out_dir.exists():
# remove possible temporary files remaining in the output directory
shutil.rmtree(out_dir)
if local_dir is None:
# clear cache
shutil.rmtree(meta_data_dir)
shutil.rmtree(videos_dir)
return lerobot_dataset
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--data-dir",
"--raw-dir",
type=Path,
required=True,
help="Root directory containing datasets (e.g. `data` or `tmp/data` or `/tmp/lerobot/data`).",
)
parser.add_argument(
"--dataset-id",
type=str,
required=True,
help="Name of the dataset (e.g. `pusht`, `aloha_sim_insertion_human`), which matches the folder where the data is stored (e.g. `data/pusht`).",
help="Directory containing input raw datasets (e.g. `data/aloha_mobile_chair_raw` or `data/pusht_raw).",
)
# TODO(rcadene): add automatic detection of the format
parser.add_argument(
"--raw-format",
type=str,
help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`). If not provided, will be detected automatically.",
required=True,
help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`, `dora_parquet`).",
)
parser.add_argument(
"--community-id",
"--repo-id",
type=str,
default="lerobot",
help="Community or user ID under which the dataset will be hosted on the Hub.",
required=True,
help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
)
parser.add_argument(
"--revision",
type=str,
default=CODEBASE_VERSION,
help="Codebase version used to generate the dataset.",
)
parser.add_argument(
"--dry-run",
type=int,
default=0,
help="Run everything without uploading to hub, for testing purposes or storing a dataset locally.",
)
parser.add_argument(
"--save-to-disk",
type=int,
default=1,
help="Save the dataset in the directory specified by `--data-dir`.",
)
parser.add_argument(
"--tests-data-dir",
"--local-dir",
type=Path,
default="tests/data",
help="Directory containing tests artifacts datasets.",
help="When provided, writes the dataset converted to LeRobotDataset format in this directory (e.g. `data/lerobot/aloha_mobile_chair`).",
)
parser.add_argument(
"--save-tests-to-disk",
"--push-to-hub",
type=int,
default=1,
help="Save the dataset with 1 episode used for unit tests in the directory specified by `--tests-data-dir`.",
help="Upload to hub.",
)
parser.add_argument(
"--fps",
@@ -325,10 +305,21 @@ def main():
help="Number of processes of Dataloader for computing the dataset statistics.",
)
parser.add_argument(
"--debug",
"--episodes",
type=int,
nargs="*",
help="When provided, only converts the provided episodes (e.g `--episodes 2 3 4`). Useful to test the code on 1 episode.",
)
parser.add_argument(
"--force-override",
type=int,
default=0,
help="Debug mode process the first episode only.",
help="When set to 1, removes provided output directory if it already exists. By default, raises a ValueError exception.",
)
parser.add_argument(
"--tests-data-dir",
type=Path,
help="When provided, save tests artifacts into the given directory for (e.g. `--tests-data-dir tests/data/lerobot/pusht`).",
)
args = parser.parse_args()