From d6571398280f7290b16a367e844e0fa677c3eebd Mon Sep 17 00:00:00 2001 From: Marina Barannikov Date: Thu, 6 Jun 2024 09:15:53 +0000 Subject: [PATCH] Updated comments --- examples/6_add_image_transforms.py | 13 +++++++++---- lerobot/scripts/visualize_image_transforms.py | 10 ---------- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/examples/6_add_image_transforms.py b/examples/6_add_image_transforms.py index 956381ecf..be0eb9570 100644 --- a/examples/6_add_image_transforms.py +++ b/examples/6_add_image_transforms.py @@ -1,3 +1,8 @@ +""" +This script demonstrates how to implement torchvision image augmentation on an instance of a LeRobotDataset and how to show some transformed images. +The transformations are passed to the dataset as an argument upon creation, and transforms are applied to the observation images before they are returned. +""" + from pathlib import Path from torchvision.transforms import ToPILImage, v2 @@ -6,7 +11,7 @@ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset to_pil = ToPILImage() -# Create a directory to store the training checkpoint. +# Create a directory to store output images output_dir = Path("outputs/image_transforms") output_dir.mkdir(parents=True, exist_ok=True) @@ -15,10 +20,10 @@ repo_id = "lerobot/aloha_static_tape" # Create a LeRobotDataset with no transformations dataset = LeRobotDataset(repo_id, transform=None) -# Get the index of the first frame in the first episode +# Get the index of the first observation in the first episode first_idx = dataset.episode_data_index["from"][0].item() -# Get the frame from the first camera +# Get the frame corresponding to the first camera frame = dataset[first_idx][dataset.camera_keys[0]] # Save the original frame @@ -35,7 +40,7 @@ transforms = v2.Compose( ] ) -# Create a LeRobotDataset with the defined transformations +# Create another LeRobotDataset with the defined transformations transformed_dataset = LeRobotDataset(repo_id, transform=transforms) # Get a frame from the transformed dataset diff --git a/lerobot/scripts/visualize_image_transforms.py b/lerobot/scripts/visualize_image_transforms.py index 494488389..5b00bcabd 100644 --- a/lerobot/scripts/visualize_image_transforms.py +++ b/lerobot/scripts/visualize_image_transforms.py @@ -10,16 +10,6 @@ to_pil = ToPILImage() def main(cfg, output_dir=Path("outputs/image_transforms")): - """ - Function to apply image transforms from a configuration and save the transformed images. - - Args: - cfg (object): Configuration object containing the image transform settings and dataset_repo_id. - output_dir (str or Path, optional): Output directory to save the transformed images. Defaults to "outputs/image_transforms". - - Returns: - None - """ dataset = LeRobotDataset(cfg.dataset_repo_id, transform=None)