Updated comments

This commit is contained in:
Marina Barannikov
2024-06-06 09:15:53 +00:00
parent b1714803a3
commit d657139828
2 changed files with 9 additions and 14 deletions

View File

@@ -1,3 +1,8 @@
"""
This script demonstrates how to implement torchvision image augmentation on an instance of a LeRobotDataset and how to show some transformed images.
The transformations are passed to the dataset as an argument upon creation, and transforms are applied to the observation images before they are returned.
"""
from pathlib import Path
from torchvision.transforms import ToPILImage, v2
@@ -6,7 +11,7 @@ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
to_pil = ToPILImage()
# Create a directory to store the training checkpoint.
# Create a directory to store output images
output_dir = Path("outputs/image_transforms")
output_dir.mkdir(parents=True, exist_ok=True)
@@ -15,10 +20,10 @@ repo_id = "lerobot/aloha_static_tape"
# Create a LeRobotDataset with no transformations
dataset = LeRobotDataset(repo_id, transform=None)
# Get the index of the first frame in the first episode
# Get the index of the first observation in the first episode
first_idx = dataset.episode_data_index["from"][0].item()
# Get the frame from the first camera
# Get the frame corresponding to the first camera
frame = dataset[first_idx][dataset.camera_keys[0]]
# Save the original frame
@@ -35,7 +40,7 @@ transforms = v2.Compose(
]
)
# Create a LeRobotDataset with the defined transformations
# Create another LeRobotDataset with the defined transformations
transformed_dataset = LeRobotDataset(repo_id, transform=transforms)
# Get a frame from the transformed dataset