diff --git a/lerobot/scripts/push_dataset_to_hub.py b/lerobot/scripts/push_dataset_to_hub.py index c6eac5e9..a5552f97 100644 --- a/lerobot/scripts/push_dataset_to_hub.py +++ b/lerobot/scripts/push_dataset_to_hub.py @@ -144,7 +144,8 @@ def push_videos_to_hub(repo_id, videos_dir, revision): def push_dataset_to_hub( - data_dir: Path, + input_data_dir: Path, + output_data_dir: Path, dataset_id: str, raw_format: str | None, community_id: str, @@ -161,34 +162,33 @@ def push_dataset_to_hub( ): repo_id = f"{community_id}/{dataset_id}" - raw_dir = data_dir / f"{dataset_id}_raw" - - out_dir = data_dir / repo_id - meta_data_dir = out_dir / "meta_data" - videos_dir = out_dir / "videos" + meta_data_dir = output_data_dir / "meta_data" + videos_dir = output_data_dir / "videos" tests_out_dir = tests_data_dir / repo_id tests_meta_data_dir = tests_out_dir / "meta_data" tests_videos_dir = tests_out_dir / "videos" - if out_dir.exists(): - shutil.rmtree(out_dir) + if output_data_dir.exists(): + shutil.rmtree(output_data_dir) if tests_out_dir.exists() and save_tests_to_disk: shutil.rmtree(tests_out_dir) - if not raw_dir.exists(): - download_raw(raw_dir, dataset_id) + if not input_data_dir.exists(): + download_raw(input_data_dir, dataset_id) if raw_format is None: # TODO(rcadene, adilzouitine): implement auto_find_raw_format raise NotImplementedError() - # raw_format = auto_find_raw_format(raw_dir) + # raw_format = auto_find_raw_format(input_data_dir) from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format) # convert dataset from original raw format to LeRobot format - hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(raw_dir, out_dir, fps, video, debug) + hf_dataset, episode_data_index, info = from_raw_to_lerobot_format( + input_data_dir, output_data_dir, fps, video, debug + ) lerobot_dataset = LeRobotDataset.from_preloaded( repo_id=repo_id, @@ -202,7 +202,7 @@ def push_dataset_to_hub( if save_to_disk: hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved - hf_dataset.save_to_disk(str(out_dir / "train")) + hf_dataset.save_to_disk(str(output_data_dir / "train")) if not dry_run or save_to_disk: # mandatory for upload @@ -236,19 +236,25 @@ def push_dataset_to_hub( fname = f"{key}_episode_{episode_index:06d}.mp4" shutil.copy(videos_dir / fname, tests_videos_dir / fname) - if not save_to_disk and out_dir.exists(): + if not save_to_disk and output_data_dir.exists(): # remove possible temporary files remaining in the output directory - shutil.rmtree(out_dir) + shutil.rmtree(output_data_dir) def main(): parser = argparse.ArgumentParser() parser.add_argument( - "--data-dir", + "--input-data-dir", type=Path, required=True, - help="Root directory containing datasets (e.g. `data` or `tmp/data` or `/tmp/lerobot/data`).", + help="Root directory containing input raw datasets (e.g. `data` or `tmp/data` or `/tmp/lerobot/data`).", + ) + parser.add_argument( + "--output-data-dir", + type=Path, + required=True, + help="Root directory containing output dataset (e.g. `data` or `tmp/data` or `/tmp/lerobot/data`).", ) parser.add_argument( "--dataset-id",