Improve push_dataset_to_hub API + Add unit tests (#231)
Co-authored-by: Remi <re.cadene@gmail.com> Co-authored-by: Simon Alibert <alibert.sim@gmail.com> Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
This commit is contained in:
@@ -18,58 +18,39 @@ Use this script to convert your dataset into LeRobot dataset format and upload i
|
||||
or store it locally. LeRobot dataset format is lightweight, fast to load from, and does not require any
|
||||
installation of neural net specific packages like pytorch, tensorflow, jax.
|
||||
|
||||
Example:
|
||||
Example of how to download raw datasets, convert them into LeRobotDataset format, and push them to the hub:
|
||||
```
|
||||
python lerobot/scripts/push_dataset_to_hub.py \
|
||||
--data-dir data \
|
||||
--dataset-id pusht \
|
||||
--raw-dir data/pusht_raw \
|
||||
--raw-format pusht_zarr \
|
||||
--community-id lerobot \
|
||||
--dry-run 1 \
|
||||
--save-to-disk 1 \
|
||||
--save-tests-to-disk 0 \
|
||||
--debug 1
|
||||
--repo-id lerobot/pusht
|
||||
|
||||
python lerobot/scripts/push_dataset_to_hub.py \
|
||||
--data-dir data \
|
||||
--dataset-id xarm_lift_medium \
|
||||
--raw-dir data/xarm_lift_medium_raw \
|
||||
--raw-format xarm_pkl \
|
||||
--community-id lerobot \
|
||||
--dry-run 1 \
|
||||
--save-to-disk 1 \
|
||||
--save-tests-to-disk 0 \
|
||||
--debug 1
|
||||
--repo-id lerobot/xarm_lift_medium
|
||||
|
||||
python lerobot/scripts/push_dataset_to_hub.py \
|
||||
--data-dir data \
|
||||
--dataset-id aloha_sim_insertion_scripted \
|
||||
--raw-dir data/aloha_sim_insertion_scripted_raw \
|
||||
--raw-format aloha_hdf5 \
|
||||
--community-id lerobot \
|
||||
--dry-run 1 \
|
||||
--save-to-disk 1 \
|
||||
--save-tests-to-disk 0 \
|
||||
--debug 1
|
||||
--repo-id lerobot/aloha_sim_insertion_scripted
|
||||
|
||||
python lerobot/scripts/push_dataset_to_hub.py \
|
||||
--data-dir data \
|
||||
--dataset-id umi_cup_in_the_wild \
|
||||
--raw-dir data/umi_cup_in_the_wild_raw \
|
||||
--raw-format umi_zarr \
|
||||
--community-id lerobot \
|
||||
--dry-run 1 \
|
||||
--save-to-disk 1 \
|
||||
--save-tests-to-disk 0 \
|
||||
--debug 1
|
||||
--repo-id lerobot/umi_cup_in_the_wild
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import shutil
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub import HfApi, create_branch
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from lerobot.common.datasets.compute_stats import compute_stats
|
||||
@@ -85,8 +66,8 @@ def get_from_raw_to_lerobot_format_fn(raw_format: str):
|
||||
from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format
|
||||
elif raw_format == "aloha_hdf5":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
|
||||
elif raw_format == "aloha_dora":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.aloha_dora_format import from_raw_to_lerobot_format
|
||||
elif raw_format == "dora_parquet":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.dora_parquet_format import from_raw_to_lerobot_format
|
||||
elif raw_format == "xarm_pkl":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import from_raw_to_lerobot_format
|
||||
else:
|
||||
@@ -147,39 +128,61 @@ def push_videos_to_hub(repo_id: str, videos_dir: str | Path, revision: str | Non
|
||||
|
||||
|
||||
def push_dataset_to_hub(
|
||||
data_dir: Path,
|
||||
dataset_id: str,
|
||||
raw_format: str | None,
|
||||
community_id: str,
|
||||
revision: str,
|
||||
dry_run: bool,
|
||||
save_to_disk: bool,
|
||||
tests_data_dir: Path,
|
||||
save_tests_to_disk: bool,
|
||||
fps: int | None,
|
||||
video: bool,
|
||||
batch_size: int,
|
||||
num_workers: int,
|
||||
debug: bool,
|
||||
raw_dir: Path,
|
||||
raw_format: str,
|
||||
repo_id: str,
|
||||
push_to_hub: bool = True,
|
||||
local_dir: Path | None = None,
|
||||
fps: int | None = None,
|
||||
video: bool = True,
|
||||
batch_size: int = 32,
|
||||
num_workers: int = 8,
|
||||
episodes: list[int] | None = None,
|
||||
force_override: bool = False,
|
||||
cache_dir: Path = Path("/tmp"),
|
||||
tests_data_dir: Path | None = None,
|
||||
):
|
||||
repo_id = f"{community_id}/{dataset_id}"
|
||||
# Check repo_id is well formatted
|
||||
if len(repo_id.split("/")) != 2:
|
||||
raise ValueError(
|
||||
f"`repo_id` is expected to contain a community or user id `/` the name of the dataset (e.g. 'lerobot/pusht'), but instead contains '{repo_id}'."
|
||||
)
|
||||
user_id, dataset_id = repo_id.split("/")
|
||||
|
||||
raw_dir = data_dir / f"{dataset_id}_raw"
|
||||
# Robustify when `raw_dir` is str instead of Path
|
||||
raw_dir = Path(raw_dir)
|
||||
if not raw_dir.exists():
|
||||
raise NotADirectoryError(
|
||||
f"{raw_dir} does not exists. Check your paths or run this command to download an existing raw dataset on the hub:"
|
||||
f"python lerobot/common/datasets/push_dataset_to_hub/_download_raw.py --raw-dir your/raw/dir --repo-id your/repo/id_raw"
|
||||
)
|
||||
|
||||
out_dir = data_dir / repo_id
|
||||
meta_data_dir = out_dir / "meta_data"
|
||||
videos_dir = out_dir / "videos"
|
||||
if local_dir:
|
||||
# Robustify when `local_dir` is str instead of Path
|
||||
local_dir = Path(local_dir)
|
||||
|
||||
tests_out_dir = tests_data_dir / repo_id
|
||||
tests_meta_data_dir = tests_out_dir / "meta_data"
|
||||
tests_videos_dir = tests_out_dir / "videos"
|
||||
# Send warning if local_dir isn't well formatted
|
||||
if local_dir.parts[-2] != user_id or local_dir.parts[-1] != dataset_id:
|
||||
warnings.warn(
|
||||
f"`local_dir` ({local_dir}) doesn't contain a community or user id `/` the name of the dataset that match the `repo_id` (e.g. 'data/lerobot/pusht'). Following this naming convention is advised, but not mandatory.",
|
||||
stacklevel=1,
|
||||
)
|
||||
|
||||
if out_dir.exists():
|
||||
shutil.rmtree(out_dir)
|
||||
# Check we don't override an existing `local_dir` by mistake
|
||||
if local_dir.exists():
|
||||
if force_override:
|
||||
shutil.rmtree(local_dir)
|
||||
else:
|
||||
raise ValueError(f"`local_dir` already exists ({local_dir}). Use `--force-override 1`.")
|
||||
|
||||
if tests_out_dir.exists() and save_tests_to_disk:
|
||||
shutil.rmtree(tests_out_dir)
|
||||
meta_data_dir = local_dir / "meta_data"
|
||||
videos_dir = local_dir / "videos"
|
||||
else:
|
||||
# Temporary directory used to store images, videos, meta_data
|
||||
meta_data_dir = Path(cache_dir) / "meta_data"
|
||||
videos_dir = Path(cache_dir) / "videos"
|
||||
|
||||
# Download the raw dataset if available
|
||||
if not raw_dir.exists():
|
||||
download_raw(raw_dir, dataset_id)
|
||||
|
||||
@@ -188,14 +191,14 @@ def push_dataset_to_hub(
|
||||
raise NotImplementedError()
|
||||
# raw_format = auto_find_raw_format(raw_dir)
|
||||
|
||||
from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
|
||||
|
||||
# convert dataset from original raw format to LeRobot format
|
||||
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(raw_dir, out_dir, fps, video, debug)
|
||||
from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
|
||||
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
|
||||
raw_dir, videos_dir, fps, video, episodes
|
||||
)
|
||||
|
||||
lerobot_dataset = LeRobotDataset.from_preloaded(
|
||||
repo_id=repo_id,
|
||||
version=revision,
|
||||
hf_dataset=hf_dataset,
|
||||
episode_data_index=episode_data_index,
|
||||
info=info,
|
||||
@@ -203,103 +206,80 @@ def push_dataset_to_hub(
|
||||
)
|
||||
stats = compute_stats(lerobot_dataset, batch_size, num_workers)
|
||||
|
||||
if save_to_disk:
|
||||
if local_dir:
|
||||
hf_dataset = hf_dataset.with_format(None) # to remove transforms that can't be saved
|
||||
hf_dataset.save_to_disk(str(out_dir / "train"))
|
||||
hf_dataset.save_to_disk(str(local_dir / "train"))
|
||||
|
||||
if not dry_run or save_to_disk:
|
||||
if push_to_hub or local_dir:
|
||||
# mandatory for upload
|
||||
save_meta_data(info, stats, episode_data_index, meta_data_dir)
|
||||
|
||||
if not dry_run:
|
||||
# TODO(rcadene): token needs to be a str | None
|
||||
hf_dataset.push_to_hub(repo_id, token=True, revision="main")
|
||||
hf_dataset.push_to_hub(repo_id, token=True, revision=revision)
|
||||
|
||||
if push_to_hub:
|
||||
hf_dataset.push_to_hub(repo_id, revision="main")
|
||||
push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
|
||||
push_meta_data_to_hub(repo_id, meta_data_dir, revision=revision)
|
||||
|
||||
if video:
|
||||
push_videos_to_hub(repo_id, videos_dir, revision="main")
|
||||
push_videos_to_hub(repo_id, videos_dir, revision=revision)
|
||||
create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
|
||||
|
||||
if save_tests_to_disk:
|
||||
if tests_data_dir:
|
||||
# get the first episode
|
||||
num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
|
||||
test_hf_dataset = hf_dataset.select(range(num_items_first_ep))
|
||||
|
||||
test_hf_dataset = test_hf_dataset.with_format(None)
|
||||
test_hf_dataset.save_to_disk(str(tests_out_dir / "train"))
|
||||
test_hf_dataset.save_to_disk(str(tests_data_dir / repo_id / "train"))
|
||||
|
||||
save_meta_data(info, stats, episode_data_index, tests_meta_data_dir)
|
||||
tests_meta_data = tests_data_dir / repo_id / "meta_data"
|
||||
save_meta_data(info, stats, episode_data_index, tests_meta_data)
|
||||
|
||||
# copy videos of first episode to tests directory
|
||||
episode_index = 0
|
||||
tests_videos_dir = tests_data_dir / repo_id / "videos"
|
||||
tests_videos_dir.mkdir(parents=True, exist_ok=True)
|
||||
for key in lerobot_dataset.video_frame_keys:
|
||||
fname = f"{key}_episode_{episode_index:06d}.mp4"
|
||||
shutil.copy(videos_dir / fname, tests_videos_dir / fname)
|
||||
|
||||
if not save_to_disk and out_dir.exists():
|
||||
# remove possible temporary files remaining in the output directory
|
||||
shutil.rmtree(out_dir)
|
||||
if local_dir is None:
|
||||
# clear cache
|
||||
shutil.rmtree(meta_data_dir)
|
||||
shutil.rmtree(videos_dir)
|
||||
|
||||
return lerobot_dataset
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--data-dir",
|
||||
"--raw-dir",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Root directory containing datasets (e.g. `data` or `tmp/data` or `/tmp/lerobot/data`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Name of the dataset (e.g. `pusht`, `aloha_sim_insertion_human`), which matches the folder where the data is stored (e.g. `data/pusht`).",
|
||||
help="Directory containing input raw datasets (e.g. `data/aloha_mobile_chair_raw` or `data/pusht_raw).",
|
||||
)
|
||||
# TODO(rcadene): add automatic detection of the format
|
||||
parser.add_argument(
|
||||
"--raw-format",
|
||||
type=str,
|
||||
help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`). If not provided, will be detected automatically.",
|
||||
required=True,
|
||||
help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`, `dora_parquet`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--community-id",
|
||||
"--repo-id",
|
||||
type=str,
|
||||
default="lerobot",
|
||||
help="Community or user ID under which the dataset will be hosted on the Hub.",
|
||||
required=True,
|
||||
help="Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--revision",
|
||||
type=str,
|
||||
default=CODEBASE_VERSION,
|
||||
help="Codebase version used to generate the dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dry-run",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Run everything without uploading to hub, for testing purposes or storing a dataset locally.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-to-disk",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Save the dataset in the directory specified by `--data-dir`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tests-data-dir",
|
||||
"--local-dir",
|
||||
type=Path,
|
||||
default="tests/data",
|
||||
help="Directory containing tests artifacts datasets.",
|
||||
help="When provided, writes the dataset converted to LeRobotDataset format in this directory (e.g. `data/lerobot/aloha_mobile_chair`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-tests-to-disk",
|
||||
"--push-to-hub",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Save the dataset with 1 episode used for unit tests in the directory specified by `--tests-data-dir`.",
|
||||
help="Upload to hub.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fps",
|
||||
@@ -325,10 +305,21 @@ def main():
|
||||
help="Number of processes of Dataloader for computing the dataset statistics.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
"--episodes",
|
||||
type=int,
|
||||
nargs="*",
|
||||
help="When provided, only converts the provided episodes (e.g `--episodes 2 3 4`). Useful to test the code on 1 episode.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force-override",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Debug mode process the first episode only.",
|
||||
help="When set to 1, removes provided output directory if it already exists. By default, raises a ValueError exception.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tests-data-dir",
|
||||
type=Path,
|
||||
help="When provided, save tests artifacts into the given directory for (e.g. `--tests-data-dir tests/data/lerobot/pusht`).",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
Reference in New Issue
Block a user