push_to_hub less hardcoded
This commit is contained in:
@@ -144,7 +144,8 @@ def push_videos_to_hub(repo_id, videos_dir, revision):
|
||||
|
||||
|
||||
def push_dataset_to_hub(
|
||||
data_dir: Path,
|
||||
input_data_dir: Path,
|
||||
output_data_dir: Path,
|
||||
dataset_id: str,
|
||||
raw_format: str | None,
|
||||
community_id: str,
|
||||
@@ -161,34 +162,33 @@ def push_dataset_to_hub(
|
||||
):
|
||||
repo_id = f"{community_id}/{dataset_id}"
|
||||
|
||||
raw_dir = data_dir / f"{dataset_id}_raw"
|
||||
|
||||
out_dir = data_dir / repo_id
|
||||
meta_data_dir = out_dir / "meta_data"
|
||||
videos_dir = out_dir / "videos"
|
||||
meta_data_dir = output_data_dir / "meta_data"
|
||||
videos_dir = output_data_dir / "videos"
|
||||
|
||||
tests_out_dir = tests_data_dir / repo_id
|
||||
tests_meta_data_dir = tests_out_dir / "meta_data"
|
||||
tests_videos_dir = tests_out_dir / "videos"
|
||||
|
||||
if out_dir.exists():
|
||||
shutil.rmtree(out_dir)
|
||||
if output_data_dir.exists():
|
||||
shutil.rmtree(output_data_dir)
|
||||
|
||||
if tests_out_dir.exists() and save_tests_to_disk:
|
||||
shutil.rmtree(tests_out_dir)
|
||||
|
||||
if not raw_dir.exists():
|
||||
download_raw(raw_dir, dataset_id)
|
||||
if not input_data_dir.exists():
|
||||
download_raw(input_data_dir, dataset_id)
|
||||
|
||||
if raw_format is None:
|
||||
# TODO(rcadene, adilzouitine): implement auto_find_raw_format
|
||||
raise NotImplementedError()
|
||||
# raw_format = auto_find_raw_format(raw_dir)
|
||||
# raw_format = auto_find_raw_format(input_data_dir)
|
||||
|
||||
from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
|
||||
|
||||
# convert dataset from original raw format to LeRobot format
|
||||
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(raw_dir, out_dir, fps, video, debug)
|
||||
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
|
||||
input_data_dir, output_data_dir, fps, video, debug
|
||||
)
|
||||
|
||||
lerobot_dataset = LeRobotDataset.from_preloaded(
|
||||
repo_id=repo_id,
|
||||
@@ -202,7 +202,7 @@ def push_dataset_to_hub(
|
||||
|
||||
if save_to_disk:
|
||||
hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
|
||||
hf_dataset.save_to_disk(str(out_dir / "train"))
|
||||
hf_dataset.save_to_disk(str(output_data_dir / "train"))
|
||||
|
||||
if not dry_run or save_to_disk:
|
||||
# mandatory for upload
|
||||
@@ -236,19 +236,25 @@ def push_dataset_to_hub(
|
||||
fname = f"{key}_episode_{episode_index:06d}.mp4"
|
||||
shutil.copy(videos_dir / fname, tests_videos_dir / fname)
|
||||
|
||||
if not save_to_disk and out_dir.exists():
|
||||
if not save_to_disk and output_data_dir.exists():
|
||||
# remove possible temporary files remaining in the output directory
|
||||
shutil.rmtree(out_dir)
|
||||
shutil.rmtree(output_data_dir)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--data-dir",
|
||||
"--input-data-dir",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Root directory containing datasets (e.g. `data` or `tmp/data` or `/tmp/lerobot/data`).",
|
||||
help="Root directory containing input raw datasets (e.g. `data` or `tmp/data` or `/tmp/lerobot/data`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-data-dir",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Root directory containing output dataset (e.g. `data` or `tmp/data` or `/tmp/lerobot/data`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-id",
|
||||
|
||||
Reference in New Issue
Block a user