Refactor crop_dataset_roi

Michel Aractingi
2025-04-22 09:31:35 +02:00
parent a7a51cfc9c
commit dc726cb9a3
2 changed files with 37 additions and 32 deletions


@@ -150,6 +150,7 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(
     new_repo_id: str,
     new_dataset_root: str,
     resize_size: Tuple[int, int] = (128, 128),
+    push_to_hub: bool = False,
 ) -> LeRobotDataset:
     """
     Converts an existing LeRobotDataset by iterating over its episodes and frames,
@@ -183,43 +184,39 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(
     # (Here we simply set the shape to be the final resize_size.)
     for key in crop_params_dict:
         if key in new_dataset.meta.info["features"]:
-            new_dataset.meta.info["features"][key]["shape"] = list(resize_size)
+            new_dataset.meta.info["features"][key]["shape"] = [3] + list(resize_size)
     # 2. Process each episode in the original dataset.
-    episodes_info = original_dataset.meta.episodes
-    # (Sort episodes by episode_index for consistency.)
-    episodes_info = sorted(episodes_info, key=lambda x: x["episode_index"])
-    # Use the first task from the episode metadata (or "unknown" if not provided)
-    task = episodes_info[0]["tasks"][0] if episodes_info[0].get("tasks") else "unknown"
-    last_episode_index = 0
-    for sample in tqdm(original_dataset):
-        episode_index = sample.pop("episode_index")
-        if episode_index != last_episode_index:
-            new_dataset.save_episode(task, encode_videos=True)
-            last_episode_index = episode_index
-        sample.pop("frame_index")
-        # Make a shallow copy of the sample (the values—e.g. torch tensors—are assumed immutable)
-        new_sample = sample.copy()
-        # Loop over each observation key that should be cropped/resized.
-        for key, params in crop_params_dict.items():
-            if key in new_sample:
-                top, left, height, width = params
-                # Apply crop then resize.
-                cropped = F.crop(new_sample[key], top, left, height, width)
-                resized = F.resize(cropped, resize_size)
-                new_sample[key] = resized
-        # Add the transformed frame to the new dataset.
-        new_dataset.add_frame(new_sample)
-    # save last episode
-    new_dataset.save_episode(task, encode_videos=True)
-    # Optionally, consolidate the new dataset to compute statistics and update video info.
-    new_dataset.consolidate(run_compute_stats=True, keep_image_files=True)
-    new_dataset.push_to_hub(tags=None)
+    prev_episode_index = 0
+    for frame_idx in tqdm(range(len(original_dataset))):
+        frame = original_dataset[frame_idx]
+        # Create a copy of the frame to add to the new dataset
+        new_frame = {}
+        for key, value in frame.items():
+            if key in ("task_index", "timestamp", "episode_index", "frame_index", "index"):
+                continue
+            if key in ("next.done", "next.reward"):
+                # if not isinstance(value, str) and len(value.shape) == 0:
+                value = value.unsqueeze(0)
+            if key in crop_params_dict:
+                # Apply crop then resize.
+                top, left, height, width = crop_params_dict[key]
+                cropped = F.crop(value, top, left, height, width)
+                value = F.resize(cropped, resize_size)
+                value = value.clamp(0, 1)
+            new_frame[key] = value
+        new_dataset.add_frame(new_frame)
+        if frame["episode_index"].item() != prev_episode_index:
+            # Save the episode
+            new_dataset.save_episode()
+            prev_episode_index = frame["episode_index"].item()
+    if push_to_hub:
+        new_dataset.push_to_hub()
     return new_dataset
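
For readers less familiar with the torchvision calls in the new loop, here is a minimal, self-contained sketch of the crop-then-resize-then-clamp transform applied to each image value; the tensor name, image size, and ROI numbers are made up for illustration:

import torch
import torchvision.transforms.functional as F

image = torch.rand(3, 480, 640)                # dummy C x H x W image with values in [0, 1]
top, left, height, width = 100, 200, 256, 256  # hypothetical ROI, as stored in crop_params_dict
cropped = F.crop(image, top, left, height, width)
resized = F.resize(cropped, [128, 128])        # matches resize_size=(128, 128)
resized = resized.clamp(0, 1)                  # mirrors the value.clamp(0, 1) in the converted loop
assert resized.shape == (3, 128, 128)          # consistent with the new feature shape [3, 128, 128]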
@@ -244,10 +241,15 @@ if __name__ == "__main__":
         default=None,
         help="The path to the JSON file containing the ROIs.",
     )
+    parser.add_argument(
+        "--push-to-hub",
+        type=bool,
+        default=False,
+        help="Whether to push the new dataset to the hub.",
+    )
     args = parser.parse_args()
-    local_files_only = args.root is not None
-    dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root, local_files_only=local_files_only)
+    dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root)
     images = get_image_from_lerobot_dataset(dataset)
     images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()}
@@ -273,6 +275,7 @@ if __name__ == "__main__":
         new_repo_id=new_repo_id,
         new_dataset_root=new_dataset_root,
         resize_size=(128, 128),
+        push_to_hub=args.push_to_hub,
     )
     meta_dir = new_dataset_root / "meta"
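
One caveat about the new CLI flag, sketched below: argparse's type=bool does not parse boolean strings, it simply applies Python's bool() to the raw argument, so any non-empty value (including "False") parses as True. The snippet is illustrative only:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--push-to-hub", type=bool, default=False)

print(parser.parse_args(["--push-to-hub", "False"]).push_to_hub)  # True, since bool("False") is True
print(parser.parse_args([]).push_to_hub)                          # False (the default)

# A common alternative that avoids the pitfall:
# parser.add_argument("--push-to-hub", action="store_true")

In other words, the flag behaves as push-if-present: only omitting it (or passing an empty string) leaves push_to_hub False.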


@@ -1,3 +1,5 @@
+# ruff: noqa: N806, N815, N803
+
 import numpy as np
 from scipy.spatial.transform import Rotation
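
For reference, the suppressed rules are pep8-naming checks: N803 (argument name should be lowercase), N806 (variable in function should be lowercase), and N815 (mixedCase variable in class scope), which would otherwise fire on math-style identifiers such as rotation-matrix names.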