diff --git a/src/lerobot/rl/crop_dataset_roi.py b/src/lerobot/rl/crop_dataset_roi.py index 281069e1..4345fed3 100644 --- a/src/lerobot/rl/crop_dataset_roi.py +++ b/src/lerobot/rl/crop_dataset_roi.py @@ -160,7 +160,7 @@ def get_image_from_lerobot_dataset(dataset: LeRobotDataset): return image_dict -def convert_lerobot_dataset_to_cropper_lerobot_dataset( +def convert_lerobot_dataset_to_cropped_lerobot_dataset( original_dataset: LeRobotDataset, crop_params_dict: dict[str, tuple[int, int, int, int]], new_repo_id: str, @@ -190,7 +190,7 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset( # 1. Create a new (empty) LeRobotDataset for writing. new_dataset = LeRobotDataset.create( repo_id=new_repo_id, - fps=original_dataset.fps, + fps=int(original_dataset.fps), root=new_dataset_root, robot_type=original_dataset.meta.robot_type, features=original_dataset.meta.info["features"], @@ -275,6 +275,12 @@ if __name__ == "__main__": default="", help="The natural language task to describe the dataset.", ) + parser.add_argument( + "--new-repo-id", + type=str, + default=None, + help="The repository id for the new cropped and resized dataset. If not provided, it defaults to `repo_id` + '_cropped_resized'.", + ) args = parser.parse_args() dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root) @@ -294,10 +300,16 @@ if __name__ == "__main__": for key, roi in rois.items(): print(f"{key}: {roi}") - new_repo_id = args.repo_id + "_cropped_resized" - new_dataset_root = Path(str(dataset.root) + "_cropped_resized") + new_repo_id = args.new_repo_id if args.new_repo_id else args.repo_id + "_cropped_resized" - cropped_resized_dataset = convert_lerobot_dataset_to_cropper_lerobot_dataset( + if args.new_repo_id: + new_dataset_name = args.new_repo_id.split("/")[-1] + # Parent 1: HF user, Parent 2: HF LeRobot Home + new_dataset_root = dataset.root.parent.parent / new_dataset_name + else: + new_dataset_root = Path(str(dataset.root) + "_cropped_resized") + + cropped_resized_dataset = convert_lerobot_dataset_to_cropped_lerobot_dataset( original_dataset=dataset, crop_params_dict=rois, new_repo_id=new_repo_id,