Save Cropped Dataset to Hub (#2071)
* fix: cast fps argument from dataset to int * fix: typo * fix: specify repo-id
This commit is contained in:
committed by
GitHub
parent
5b647e3bcb
commit
e3b572992e
@@ -160,7 +160,7 @@ def get_image_from_lerobot_dataset(dataset: LeRobotDataset):
|
|||||||
return image_dict
|
return image_dict
|
||||||
|
|
||||||
|
|
||||||
def convert_lerobot_dataset_to_cropper_lerobot_dataset(
|
def convert_lerobot_dataset_to_cropped_lerobot_dataset(
|
||||||
original_dataset: LeRobotDataset,
|
original_dataset: LeRobotDataset,
|
||||||
crop_params_dict: dict[str, tuple[int, int, int, int]],
|
crop_params_dict: dict[str, tuple[int, int, int, int]],
|
||||||
new_repo_id: str,
|
new_repo_id: str,
|
||||||
@@ -190,7 +190,7 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(
|
|||||||
# 1. Create a new (empty) LeRobotDataset for writing.
|
# 1. Create a new (empty) LeRobotDataset for writing.
|
||||||
new_dataset = LeRobotDataset.create(
|
new_dataset = LeRobotDataset.create(
|
||||||
repo_id=new_repo_id,
|
repo_id=new_repo_id,
|
||||||
fps=original_dataset.fps,
|
fps=int(original_dataset.fps),
|
||||||
root=new_dataset_root,
|
root=new_dataset_root,
|
||||||
robot_type=original_dataset.meta.robot_type,
|
robot_type=original_dataset.meta.robot_type,
|
||||||
features=original_dataset.meta.info["features"],
|
features=original_dataset.meta.info["features"],
|
||||||
@@ -275,6 +275,12 @@ if __name__ == "__main__":
|
|||||||
default="",
|
default="",
|
||||||
help="The natural language task to describe the dataset.",
|
help="The natural language task to describe the dataset.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--new-repo-id",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="The repository id for the new cropped and resized dataset. If not provided, it defaults to `repo_id` + '_cropped_resized'.",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root)
|
dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root)
|
||||||
@@ -294,10 +300,16 @@ if __name__ == "__main__":
|
|||||||
for key, roi in rois.items():
|
for key, roi in rois.items():
|
||||||
print(f"{key}: {roi}")
|
print(f"{key}: {roi}")
|
||||||
|
|
||||||
new_repo_id = args.repo_id + "_cropped_resized"
|
new_repo_id = args.new_repo_id if args.new_repo_id else args.repo_id + "_cropped_resized"
|
||||||
new_dataset_root = Path(str(dataset.root) + "_cropped_resized")
|
|
||||||
|
|
||||||
cropped_resized_dataset = convert_lerobot_dataset_to_cropper_lerobot_dataset(
|
if args.new_repo_id:
|
||||||
|
new_dataset_name = args.new_repo_id.split("/")[-1]
|
||||||
|
# Parent 1: HF user, Parent 2: HF LeRobot Home
|
||||||
|
new_dataset_root = dataset.root.parent.parent / new_dataset_name
|
||||||
|
else:
|
||||||
|
new_dataset_root = Path(str(dataset.root) + "_cropped_resized")
|
||||||
|
|
||||||
|
cropped_resized_dataset = convert_lerobot_dataset_to_cropped_lerobot_dataset(
|
||||||
original_dataset=dataset,
|
original_dataset=dataset,
|
||||||
crop_params_dict=rois,
|
crop_params_dict=rois,
|
||||||
new_repo_id=new_repo_id,
|
new_repo_id=new_repo_id,
|
||||||
|
|||||||
Reference in New Issue
Block a user