Save Cropped Dataset to Hub (#2071)

* fix: cast fps argument from dataset to int

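* fix: typo in function name (`convert_lerobot_dataset_to_cropper_lerobot_dataset` -> `convert_lerobot_dataset_to_cropped_lerobot_dataset`)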

* fix: specify repo-id (add a `--new-repo-id` option for the cropped and resized dataset)
This commit is contained in:
Francesco Capuano
2025-09-27 16:07:53 +02:00
committed by GitHub
parent 5b647e3bcb
commit e3b572992e

View File

@@ -160,7 +160,7 @@ def get_image_from_lerobot_dataset(dataset: LeRobotDataset):
return image_dict
def convert_lerobot_dataset_to_cropper_lerobot_dataset(
def convert_lerobot_dataset_to_cropped_lerobot_dataset(
original_dataset: LeRobotDataset,
crop_params_dict: dict[str, tuple[int, int, int, int]],
new_repo_id: str,
@@ -190,7 +190,7 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(
# 1. Create a new (empty) LeRobotDataset for writing.
new_dataset = LeRobotDataset.create(
repo_id=new_repo_id,
fps=original_dataset.fps,
fps=int(original_dataset.fps),
root=new_dataset_root,
robot_type=original_dataset.meta.robot_type,
features=original_dataset.meta.info["features"],
@@ -275,6 +275,12 @@ if __name__ == "__main__":
default="",
help="The natural language task to describe the dataset.",
)
parser.add_argument(
"--new-repo-id",
type=str,
default=None,
help="The repository id for the new cropped and resized dataset. If not provided, it defaults to `repo_id` + '_cropped_resized'.",
)
args = parser.parse_args()
dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root)
@@ -294,10 +300,16 @@ if __name__ == "__main__":
for key, roi in rois.items():
print(f"{key}: {roi}")
new_repo_id = args.repo_id + "_cropped_resized"
new_dataset_root = Path(str(dataset.root) + "_cropped_resized")
new_repo_id = args.new_repo_id if args.new_repo_id else args.repo_id + "_cropped_resized"
cropped_resized_dataset = convert_lerobot_dataset_to_cropper_lerobot_dataset(
if args.new_repo_id:
new_dataset_name = args.new_repo_id.split("/")[-1]
# Parent 1: HF user, Parent 2: HF LeRobot Home
new_dataset_root = dataset.root.parent.parent / new_dataset_name
else:
new_dataset_root = Path(str(dataset.root) + "_cropped_resized")
cropped_resized_dataset = convert_lerobot_dataset_to_cropped_lerobot_dataset(
original_dataset=dataset,
crop_params_dict=rois,
new_repo_id=new_repo_id,