diff --git a/lerobot/scripts/server/crop_dataset_roi.py b/lerobot/scripts/server/crop_dataset_roi.py index cad3419a..a6afcbfb 100644 --- a/lerobot/scripts/server/crop_dataset_roi.py +++ b/lerobot/scripts/server/crop_dataset_roi.py @@ -150,6 +150,7 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset( new_repo_id: str, new_dataset_root: str, resize_size: Tuple[int, int] = (128, 128), + push_to_hub: bool = False, ) -> LeRobotDataset: """ Converts an existing LeRobotDataset by iterating over its episodes and frames, @@ -183,43 +184,39 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset( # (Here we simply set the shape to be the final resize_size.) for key in crop_params_dict: if key in new_dataset.meta.info["features"]: - new_dataset.meta.info["features"][key]["shape"] = list(resize_size) + new_dataset.meta.info["features"][key]["shape"] = [3] + list(resize_size) - # 2. Process each episode in the original dataset. - episodes_info = original_dataset.meta.episodes - # (Sort episodes by episode_index for consistency.) + prev_episode_index = 0 + for frame_idx in tqdm(range(len(original_dataset))): + frame = original_dataset[frame_idx] - episodes_info = sorted(episodes_info, key=lambda x: x["episode_index"]) - # Use the first task from the episode metadata (or "unknown" if not provided) - task = episodes_info[0]["tasks"][0] if episodes_info[0].get("tasks") else "unknown" + # Create a copy of the frame to add to the new dataset + new_frame = {} + for key, value in frame.items(): + if key in ("task_index", "timestamp", "episode_index", "frame_index", "index"): + continue + if key in ("next.done", "next.reward"): + # if not isinstance(value, str) and len(value.shape) == 0: + value = value.unsqueeze(0) - last_episode_index = 0 - for sample in tqdm(original_dataset): - episode_index = sample.pop("episode_index") - if episode_index != last_episode_index: - new_dataset.save_episode(task, encode_videos=True) - last_episode_index = episode_index - sample.pop("frame_index") - # Make a shallow copy of the sample (the values—e.g. torch tensors—are assumed immutable) - new_sample = sample.copy() - # Loop over each observation key that should be cropped/resized. - for key, params in crop_params_dict.items(): - if key in new_sample: - top, left, height, width = params + if key in crop_params_dict: + top, left, height, width = crop_params_dict[key] # Apply crop then resize. - cropped = F.crop(new_sample[key], top, left, height, width) - resized = F.resize(cropped, resize_size) - new_sample[key] = resized - # Add the transformed frame to the new dataset. - new_dataset.add_frame(new_sample) + cropped = F.crop(value, top, left, height, width) + value = F.resize(cropped, resize_size) + value = value.clamp(0, 1) - # save last episode - new_dataset.save_episode(task, encode_videos=True) + new_frame[key] = value - # Optionally, consolidate the new dataset to compute statistics and update video info. - new_dataset.consolidate(run_compute_stats=True, keep_image_files=True) + new_dataset.add_frame(new_frame) - new_dataset.push_to_hub(tags=None) + if frame["episode_index"].item() != prev_episode_index: + # Save the episode + new_dataset.save_episode() + prev_episode_index = frame["episode_index"].item() + + if push_to_hub: + new_dataset.push_to_hub() return new_dataset @@ -244,10 +241,15 @@ if __name__ == "__main__": default=None, help="The path to the JSON file containing the ROIs.", ) + parser.add_argument( + "--push-to-hub", + type=bool, + default=False, + help="Whether to push the new dataset to the hub.", + ) args = parser.parse_args() - local_files_only = args.root is not None - dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root, local_files_only=local_files_only) + dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root) images = get_image_from_lerobot_dataset(dataset) images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()} @@ -273,6 +275,7 @@ if __name__ == "__main__": new_repo_id=new_repo_id, new_dataset_root=new_dataset_root, resize_size=(128, 128), + push_to_hub=args.push_to_hub, ) meta_dir = new_dataset_root / "meta" diff --git a/lerobot/scripts/server/kinematics.py b/lerobot/scripts/server/kinematics.py index fb16548a..1928a18a 100644 --- a/lerobot/scripts/server/kinematics.py +++ b/lerobot/scripts/server/kinematics.py @@ -1,3 +1,5 @@ +# ruff: noqa: N806, N815, N803 + import numpy as np from scipy.spatial.transform import Rotation