push_to_hub less hardcoded

This commit is contained in:
Thomas Wolf
2024-05-29 11:39:25 +02:00
parent f409bee6b1
commit ce5329cf44

View File

@@ -144,7 +144,8 @@ def push_videos_to_hub(repo_id, videos_dir, revision):
def push_dataset_to_hub(
data_dir: Path,
input_data_dir: Path,
output_data_dir: Path,
dataset_id: str,
raw_format: str | None,
community_id: str,
@@ -161,34 +162,33 @@ def push_dataset_to_hub(
):
repo_id = f"{community_id}/{dataset_id}"
raw_dir = data_dir / f"{dataset_id}_raw"
out_dir = data_dir / repo_id
meta_data_dir = out_dir / "meta_data"
videos_dir = out_dir / "videos"
meta_data_dir = output_data_dir / "meta_data"
videos_dir = output_data_dir / "videos"
tests_out_dir = tests_data_dir / repo_id
tests_meta_data_dir = tests_out_dir / "meta_data"
tests_videos_dir = tests_out_dir / "videos"
if out_dir.exists():
shutil.rmtree(out_dir)
if output_data_dir.exists():
shutil.rmtree(output_data_dir)
if tests_out_dir.exists() and save_tests_to_disk:
shutil.rmtree(tests_out_dir)
if not raw_dir.exists():
download_raw(raw_dir, dataset_id)
if not input_data_dir.exists():
download_raw(input_data_dir, dataset_id)
if raw_format is None:
# TODO(rcadene, adilzouitine): implement auto_find_raw_format
raise NotImplementedError()
# raw_format = auto_find_raw_format(raw_dir)
# raw_format = auto_find_raw_format(input_data_dir)
from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
# convert dataset from original raw format to LeRobot format
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(raw_dir, out_dir, fps, video, debug)
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
input_data_dir, output_data_dir, fps, video, debug
)
lerobot_dataset = LeRobotDataset.from_preloaded(
repo_id=repo_id,
@@ -202,7 +202,7 @@ def push_dataset_to_hub(
if save_to_disk:
hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
hf_dataset.save_to_disk(str(out_dir / "train"))
hf_dataset.save_to_disk(str(output_data_dir / "train"))
if not dry_run or save_to_disk:
# mandatory for upload
@@ -236,19 +236,25 @@ def push_dataset_to_hub(
fname = f"{key}_episode_{episode_index:06d}.mp4"
shutil.copy(videos_dir / fname, tests_videos_dir / fname)
if not save_to_disk and out_dir.exists():
if not save_to_disk and output_data_dir.exists():
# remove possible temporary files remaining in the output directory
shutil.rmtree(out_dir)
shutil.rmtree(output_data_dir)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--data-dir",
"--input-data-dir",
type=Path,
required=True,
help="Root directory containing datasets (e.g. `data` or `tmp/data` or `/tmp/lerobot/data`).",
help="Root directory containing input raw datasets (e.g. `data` or `tmp/data` or `/tmp/lerobot/data`).",
)
parser.add_argument(
"--output-data-dir",
type=Path,
required=True,
help="Root directory containing output dataset (e.g. `data` or `tmp/data` or `/tmp/lerobot/data`).",
)
parser.add_argument(
"--dataset-id",