Refactor crop_dataset_roi
This commit is contained in:
@@ -150,6 +150,7 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(
|
||||
new_repo_id: str,
|
||||
new_dataset_root: str,
|
||||
resize_size: Tuple[int, int] = (128, 128),
|
||||
push_to_hub: bool = False,
|
||||
) -> LeRobotDataset:
|
||||
"""
|
||||
Converts an existing LeRobotDataset by iterating over its episodes and frames,
|
||||
@@ -183,43 +184,39 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(
|
||||
# (Here we simply set the shape to be the final resize_size.)
|
||||
for key in crop_params_dict:
|
||||
if key in new_dataset.meta.info["features"]:
|
||||
new_dataset.meta.info["features"][key]["shape"] = list(resize_size)
|
||||
new_dataset.meta.info["features"][key]["shape"] = [3] + list(resize_size)
|
||||
|
||||
# 2. Process each episode in the original dataset.
|
||||
episodes_info = original_dataset.meta.episodes
|
||||
# (Sort episodes by episode_index for consistency.)
|
||||
prev_episode_index = 0
|
||||
for frame_idx in tqdm(range(len(original_dataset))):
|
||||
frame = original_dataset[frame_idx]
|
||||
|
||||
episodes_info = sorted(episodes_info, key=lambda x: x["episode_index"])
|
||||
# Use the first task from the episode metadata (or "unknown" if not provided)
|
||||
task = episodes_info[0]["tasks"][0] if episodes_info[0].get("tasks") else "unknown"
|
||||
# Create a copy of the frame to add to the new dataset
|
||||
new_frame = {}
|
||||
for key, value in frame.items():
|
||||
if key in ("task_index", "timestamp", "episode_index", "frame_index", "index"):
|
||||
continue
|
||||
if key in ("next.done", "next.reward"):
|
||||
# if not isinstance(value, str) and len(value.shape) == 0:
|
||||
value = value.unsqueeze(0)
|
||||
|
||||
last_episode_index = 0
|
||||
for sample in tqdm(original_dataset):
|
||||
episode_index = sample.pop("episode_index")
|
||||
if episode_index != last_episode_index:
|
||||
new_dataset.save_episode(task, encode_videos=True)
|
||||
last_episode_index = episode_index
|
||||
sample.pop("frame_index")
|
||||
# Make a shallow copy of the sample (the values—e.g. torch tensors—are assumed immutable)
|
||||
new_sample = sample.copy()
|
||||
# Loop over each observation key that should be cropped/resized.
|
||||
for key, params in crop_params_dict.items():
|
||||
if key in new_sample:
|
||||
top, left, height, width = params
|
||||
if key in crop_params_dict:
|
||||
top, left, height, width = crop_params_dict[key]
|
||||
# Apply crop then resize.
|
||||
cropped = F.crop(new_sample[key], top, left, height, width)
|
||||
resized = F.resize(cropped, resize_size)
|
||||
new_sample[key] = resized
|
||||
# Add the transformed frame to the new dataset.
|
||||
new_dataset.add_frame(new_sample)
|
||||
cropped = F.crop(value, top, left, height, width)
|
||||
value = F.resize(cropped, resize_size)
|
||||
value = value.clamp(0, 1)
|
||||
|
||||
# save last episode
|
||||
new_dataset.save_episode(task, encode_videos=True)
|
||||
new_frame[key] = value
|
||||
|
||||
# Optionally, consolidate the new dataset to compute statistics and update video info.
|
||||
new_dataset.consolidate(run_compute_stats=True, keep_image_files=True)
|
||||
new_dataset.add_frame(new_frame)
|
||||
|
||||
new_dataset.push_to_hub(tags=None)
|
||||
if frame["episode_index"].item() != prev_episode_index:
|
||||
# Save the episode
|
||||
new_dataset.save_episode()
|
||||
prev_episode_index = frame["episode_index"].item()
|
||||
|
||||
if push_to_hub:
|
||||
new_dataset.push_to_hub()
|
||||
|
||||
return new_dataset
|
||||
|
||||
@@ -244,10 +241,15 @@ if __name__ == "__main__":
|
||||
default=None,
|
||||
help="The path to the JSON file containing the ROIs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push-to-hub",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Whether to push the new dataset to the hub.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
local_files_only = args.root is not None
|
||||
dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root, local_files_only=local_files_only)
|
||||
dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root)
|
||||
|
||||
images = get_image_from_lerobot_dataset(dataset)
|
||||
images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()}
|
||||
@@ -273,6 +275,7 @@ if __name__ == "__main__":
|
||||
new_repo_id=new_repo_id,
|
||||
new_dataset_root=new_dataset_root,
|
||||
resize_size=(128, 128),
|
||||
push_to_hub=args.push_to_hub,
|
||||
)
|
||||
|
||||
meta_dir = new_dataset_root / "meta"
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
# ruff: noqa: N806, N815, N803
|
||||
|
||||
import numpy as np
|
||||
from scipy.spatial.transform import Rotation
|
||||
|
||||
|
||||
Reference in New Issue
Block a user