From 49ae3e19e1cca2326282b4ec8d76c09da84bbab5 Mon Sep 17 00:00:00 2001 From: Remi Cadene Date: Fri, 27 Dec 2024 17:49:47 +0100 Subject: [PATCH] Add clone, delete, WIP on remove_episode, drop_frame --- lerobot/common/datasets/lerobot_dataset.py | 29 ++++++++ lerobot/scripts/manage_dataset.py | 84 ++++++++++++++++++---- 2 files changed, 101 insertions(+), 12 deletions(-) diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index f40476e8b..89c83a2e0 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -964,6 +964,35 @@ class LeRobotDataset(torch.utils.data.Dataset): obj.video_backend = video_backend if video_backend is not None else "pyav" return obj + def clone(self, new_repo_id: str, new_root: str | Path | None = None) -> "LeRobotDataset": + return LeRobotDataset.create( + repo_id=new_repo_id, + fps=self.fps, + root=new_root, + robot=self.robot, + robot_type=self.robot_type, + features=self.features, + use_videos=self.use_videos, + tolerance_s=self.tolerance_s, + image_writer_processes=self.image_writer_processes, + image_writer_threads=self.image_writer_threads, + video_backend=self.video_backend, + ) + + def delete(self): + """Delete the dataset locally. If it was pushed to the hub, you can still access it by downloading it again.""" + shutil.rmtree(self.root) + + def remove_episode(self, episode: int | list[int]): + if isinstance(episode, int): + episode = [episode] + + for ep in episode: + self.meta.info + + def drop_frame(self, episode_range: dict[int, list[tuple[int, int]]]): + pass + class MultiLeRobotDataset(torch.utils.data.Dataset): """A dataset consisting of multiple underlying `LeRobotDataset`s. 
diff --git a/lerobot/scripts/manage_dataset.py b/lerobot/scripts/manage_dataset.py index 312fb7f1c..f13aa0d80 100644 --- a/lerobot/scripts/manage_dataset.py +++ b/lerobot/scripts/manage_dataset.py @@ -28,6 +28,28 @@ from pathlib import Path from lerobot.common.datasets.lerobot_dataset import LeRobotDataset + +def parse_episode_range_string(ep_range_str): + parts = ep_range_str.split("-") + if len(parts) != 3: + raise ValueError( + f"Invalid episode range string '{ep_range_str}'. Expected format: 'EP-FROM-TO', e.g., '1-5-10'." + ) + ep, start, end = parts + return int(ep), int(start), int(end) + + +def parse_episode_range_strings(ep_range_strings): + ep_ranges = {} + for ep_range_str in ep_range_strings: + ep, start, end = parse_episode_range_string(ep_range_str) + if ep not in ep_ranges: + ep_ranges[ep] = [] + ep_ranges[ep].append((start, end)) + + return ep_ranges + + if __name__ == "__main__": parser = argparse.ArgumentParser() subparsers = parser.add_subparsers(dest="mode", required=True) @@ -38,7 +60,7 @@ if __name__ == "__main__": "--root", type=Path, default=None, - help="Root directory where the dataset will be stored (e.g. 'dataset/path').", + help="Root directory where the dataset is stored (e.g. 'dataset/path').", ) base_parser.add_argument( "--repo-id", @@ -53,7 +75,6 @@ if __name__ == "__main__": help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.", ) - ############################################################################ # consolidate parser_conso = subparsers.add_parser("consolidate", parents=[base_parser]) @@ -93,6 +114,45 @@ if __name__ == "__main__": help="Create a private dataset repository on the Hugging Face Hub. 
Push publicly by default.", ) + ############################################################################ + # clone + parser_clone = subparsers.add_parser("clone", parents=[base_parser]) + parser_clone.add_argument( + "--new-root", + type=Path, + default=None, + help="New root directory where the dataset is stored (e.g. 'dataset/path').", + ) + parser_clone.add_argument( + "--new-repo-id", + type=str, + help="New dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).", + ) + + ############################################################################ + # delete + parser_del = subparsers.add_parser("delete", parents=[base_parser]) + + ############################################################################ + # remove_episode + parser_rm_ep = subparsers.add_parser("remove_episode", parents=[base_parser]) + parser_rm_ep.add_argument( + "--episode", + type=int, + nargs="*", + help="List of one or several episodes to be removed from the dataset locally.", + ) + + ############################################################################ + # drop_frame + parser_drop_frame = subparsers.add_parser("drop_frame", parents=[base_parser]) + parser_drop_frame.add_argument( + "--episode-range", + type=str, + nargs="*", + help="List of one or several frame ranges per episode to be removed from the dataset locally. 
For instance, using `--episode-range 0-0-10 3-5-20` will remove from episode 0, the frames from indices 0 to 10 excluded, and from episode 3 the frames from indices 5 to 20.", + ) + args = parser.parse_args() kwargs = vars(args) @@ -114,15 +174,15 @@ if __name__ == "__main__": private = kwargs.pop("private") == 1 dataset.push_to_hub(private=private, **kwargs) + elif mode == "clone": + dataset.clone(**kwargs) + + elif mode == "delete": + dataset.delete(**kwargs) + elif mode == "remove_episode": - remove_episode(**kwargs) + dataset.remove_episode(**kwargs) - elif mode == "delete_dataset": - delete_dataset() - - elif mode == "_episode": - - - - - \ No newline at end of file + elif mode == "drop_frame": + ep_range = parse_episode_range_strings(kwargs.pop("episode_range")) + dataset.drop_frame(episode_range=ep_range, **kwargs)