Add docs for LeRobot Image transforms (#1972)
* Remove unused scripts, add docs for image transforms, and add an example
* fix(examples): move train_policy.py under examples, remove outdated readme parts
* Remove script that's copied to the train folder
* Remove outdated links to examples and example tests
examples/dataset/load_lerobot_dataset.py (new file, 148 lines)
@@ -0,0 +1,148 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
|
||||
This script demonstrates the use of `LeRobotDataset` class for handling and processing robotic datasets from Hugging Face.
|
||||
It illustrates how to load datasets, manipulate them, and apply transformations suitable for machine learning tasks in PyTorch.
|
||||
|
||||
Features included in this script:
|
||||
- Viewing a dataset's metadata and exploring its properties.
|
||||
- Loading an existing dataset from the hub or a subset of it.
|
||||
- Accessing frames by episode number.
|
||||
- Using advanced dataset features like timestamp-based frame selection.
|
||||
- Demonstrating compatibility with PyTorch DataLoader for batch processing.
|
||||
|
||||
The script ends with examples of how to batch process data using PyTorch's DataLoader.
|
||||
"""
|
||||
|
||||
from pprint import pprint

import torch
from huggingface_hub import HfApi

import lerobot
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata

# We ported a number of existing datasets ourselves; use this to see the list:
print("List of available datasets:")
pprint(lerobot.available_datasets)

# You can also browse through the datasets created/ported by the community on the hub using the Hub API:
hub_api = HfApi()
repo_ids = [info.id for info in hub_api.list_datasets(task_categories="robotics", tags=["LeRobot"])]
pprint(repo_ids)

# Or simply explore them in your web browser directly at:
# https://huggingface.co/datasets?other=LeRobot

# Let's use this one for this example:
repo_id = "lerobot/aloha_mobile_cabinet"
# We can fetch its metadata to learn more about it:
ds_meta = LeRobotDatasetMetadata(repo_id)

# By instantiating just this class, you can quickly access useful information about the content and the
# structure of the dataset without downloading the actual data yet (only metadata files, which are
# lightweight).
print(f"Total number of episodes: {ds_meta.total_episodes}")
print(f"Average number of frames per episode: {ds_meta.total_frames / ds_meta.total_episodes:.3f}")
print(f"Frames per second used during data collection: {ds_meta.fps}")
print(f"Robot type: {ds_meta.robot_type}")
print(f"Keys to access images from cameras: {ds_meta.camera_keys}\n")

print("Tasks:")
print(ds_meta.tasks)
print("Features:")
pprint(ds_meta.features)

# You can also get a short summary by simply printing the object:
print(ds_meta)

# You can then load the actual dataset from the hub.
# Either load any subset of episodes:
dataset = LeRobotDataset(repo_id, episodes=[0, 10, 11, 23])

# And see how many frames you have:
print(f"Selected episodes: {dataset.episodes}")
print(f"Number of episodes selected: {dataset.num_episodes}")
print(f"Number of frames selected: {dataset.num_frames}")

# Or simply load the entire dataset:
dataset = LeRobotDataset(repo_id)
print(f"Number of episodes selected: {dataset.num_episodes}")
print(f"Number of frames selected: {dataset.num_frames}")

# The previous metadata class is contained in the 'meta' attribute of the dataset:
print(dataset.meta)

# LeRobotDataset actually wraps an underlying Hugging Face dataset
# (see https://huggingface.co/docs/datasets for more information).
print(dataset.hf_dataset)

# LeRobotDataset also subclasses PyTorch's Dataset, so you can do everything you know and love from
# working with the latter, like iterating through the dataset.
# __getitem__ iterates over the frames of the dataset. Since our datasets are also structured by
# episodes, you can access the frame indices of any episode using dataset.meta.episodes. Here, we access
# the frame indices associated with the first episode:
episode_index = 0
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]

# Then we grab all the image frames from the first camera:
camera_key = dataset.meta.camera_keys[0]
frames = [dataset[idx][camera_key] for idx in range(from_idx, to_idx)]

# The objects returned by the dataset are all torch.Tensors
print(type(frames[0]))
print(frames[0].shape)

# Since we're using PyTorch, the shape follows the channel-first convention (c, h, w).
# We can compare this shape with the information available for that feature:
pprint(dataset.features[camera_key])
# In particular:
print(dataset.features[camera_key]["shape"])
# The shape is in (h, w, c), which is a more universal format.

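# As a quick sanity check (a minimal sketch, assuming the feature's "shape" entry is
# stored as (h, w, c)), the two conventions are just permutations of each other:
c, h, w = frames[0].shape
assert tuple(dataset.features[camera_key]["shape"]) == (h, w, c)
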
# For many machine learning applications we need to load a history of past observations or trajectories
# of future actions. Our datasets can load previous and future frames for each key/modality, using
# timestamp differences relative to the currently loaded frame. For instance:
delta_timestamps = {
    # loads 4 images: 1 second before the current frame, 500 ms before, 200 ms before, and the current frame
    camera_key: [-1, -0.5, -0.20, 0],
    # loads 6 state vectors: 1.5 seconds, 1 second, 500 ms, 200 ms, and 100 ms before, plus the current frame
    "observation.state": [-1.5, -1, -0.5, -0.20, -0.10, 0],
    # loads 64 action vectors: the current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
    "action": [t / dataset.fps for t in range(64)],
}
# Note that these delta_timestamps values must be multiples of (1/fps) so that, when added to any frame
# timestamp, they still yield a valid frame timestamp.

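# A less error-prone way to satisfy this constraint (a minimal sketch, not part of
# the original example) is to derive the values from integer frame offsets, which
# are multiples of 1/fps by construction. The offsets below are hypothetical:
frame_offsets = [-50, -25, -10, 0]  # some frames in the past, plus the current frame
delta_timestamps_from_offsets = {camera_key: [f / ds_meta.fps for f in frame_offsets]}
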
dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
print(f"\n{dataset[0][camera_key].shape=}")  # (4, c, h, w)
print(f"{dataset[0]['observation.state'].shape=}")  # (6, c)
print(f"{dataset[0]['action'].shape=}\n")  # (64, c)

# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers because they are just
# PyTorch datasets.
dataloader = torch.utils.data.DataLoader(
    dataset,
    num_workers=4,
    batch_size=32,
    shuffle=True,
)

for batch in dataloader:
    print(f"{batch[camera_key].shape=}")  # (32, 4, c, h, w)
    print(f"{batch['observation.state'].shape=}")  # (32, 6, c)
    print(f"{batch['action'].shape=}")  # (32, 64, c)
    break
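
# In an actual training loop you would typically move the tensors to your training
# device before the forward pass (a minimal sketch, not part of the original example;
# `batch` here is the last batch from the loop above):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
print(f"Moved batch tensors to: {device}")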
examples/dataset/use_dataset_image_transforms.py (new file, 177 lines)
@@ -0,0 +1,177 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
|
||||
This example demonstrates how to use image transforms with LeRobot datasets for data augmentation during training.
|
||||
|
||||
Image transforms are applied to camera frames to improve model robustness and generalization. They are applied
|
||||
at training time only, not during dataset recording, allowing you to experiment with different augmentations
|
||||
without re-recording data.
|
||||
"""
|
||||
|
||||
import torch
from torchvision.transforms import v2
from torchvision.transforms.functional import to_pil_image

from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.transforms import ImageTransformConfig, ImageTransforms, ImageTransformsConfig

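# At a glance, the pattern used throughout this example: build a config, wrap it in
# ImageTransforms, and pass it to LeRobotDataset (a condensed sketch of the examples
# below; `some_repo_id` is a placeholder):
#
#   cfg = ImageTransformsConfig(enable=True)
#   dataset = LeRobotDataset(some_repo_id, image_transforms=ImageTransforms(cfg))

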
def save_image(tensor, filename):
    """Helper function to save a tensor as an image file."""
    if tensor.dim() == 3:  # [C, H, W]
        if tensor.max() > 1.0:
            tensor = tensor / 255.0
        tensor = torch.clamp(tensor, 0.0, 1.0)
        pil_image = to_pil_image(tensor)
        pil_image.save(filename)
        print(f"Saved: {filename}")
    else:
        print(f"Skipped {filename}: unexpected tensor shape {tensor.shape}")


def example_1_default_transforms():
    """Example 1: Use default transform configuration and save original vs transformed images"""
    print("\nExample 1: Default Transform Configuration with Image Saving")

    repo_id = "pepijn223/record_main_0"  # Example dataset

    try:
        # Load dataset without transforms (original)
        dataset_original = LeRobotDataset(repo_id=repo_id)

        # Load dataset with transforms enabled
        transforms_config = ImageTransformsConfig(
            enable=True,  # Enable transforms (disabled by default)
            max_num_transforms=2,  # Apply up to 2 transforms per frame
            random_order=False,  # Apply in standard order
        )
        dataset_with_transforms = LeRobotDataset(
            repo_id=repo_id, image_transforms=ImageTransforms(transforms_config)
        )

        # Save original and transformed images for comparison
        if len(dataset_original) > 0:
            frame_idx = 0  # Use first frame
            original_sample = dataset_original[frame_idx]
            transformed_sample = dataset_with_transforms[frame_idx]

            print(f"Saving comparison images (frame {frame_idx}):")

            for cam_key in dataset_original.meta.camera_keys:
                if cam_key in original_sample and cam_key in transformed_sample:
                    cam_name = cam_key.replace(".", "_").replace("/", "_")

                    # Save original and transformed images
                    save_image(original_sample[cam_key], f"{cam_name}_original.png")
                    save_image(transformed_sample[cam_key], f"{cam_name}_transformed.png")

    except Exception as e:
        print(f"Could not load dataset '{repo_id}': {e}")


def example_2_custom_transforms():
    """Example 2: Create custom transform configuration and save examples"""
    print("\nExample 2: Custom Transform Configuration")

    repo_id = "pepijn223/record_main_0"  # Example dataset

    try:
        # Create custom transform configuration with strong effects
        custom_transforms_config = ImageTransformsConfig(
            enable=True,
            max_num_transforms=2,  # Apply up to 2 transforms per frame
            random_order=True,  # Apply transforms in random order
            tfs={
                "brightness": ImageTransformConfig(
                    weight=1.0,
                    type="ColorJitter",
                    kwargs={"brightness": (0.5, 1.5)},  # Strong brightness range
                ),
                "contrast": ImageTransformConfig(
                    weight=1.0,  # Higher weight = more likely to be selected
                    type="ColorJitter",
                    kwargs={"contrast": (0.6, 1.4)},  # Strong contrast
                ),
                "sharpness": ImageTransformConfig(
                    weight=0.5,  # Lower weight = less likely to be selected
                    type="SharpnessJitter",
                    kwargs={"sharpness": (0.2, 2.0)},  # Strong sharpness variation
                ),
            },
        )

        dataset_with_custom_transforms = LeRobotDataset(
            repo_id=repo_id, image_transforms=ImageTransforms(custom_transforms_config)
        )

        # Save examples with strong transforms
        if len(dataset_with_custom_transforms) > 0:
            sample = dataset_with_custom_transforms[0]
            print("Saving custom transform examples:")

            for cam_key in dataset_with_custom_transforms.meta.camera_keys:
                if cam_key in sample:
                    cam_name = cam_key.replace(".", "_").replace("/", "_")
                    save_image(sample[cam_key], f"{cam_name}_custom_transforms.png")

    except Exception as e:
        print(f"Could not load dataset '{repo_id}': {e}")


def example_3_torchvision_transforms():
    """Example 3: Use pure torchvision transforms and save examples"""
    print("\nExample 3: Pure Torchvision Transforms")

    repo_id = "pepijn223/record_main_0"  # Example dataset

    try:
        # Create a torchvision transform pipeline
        torchvision_transforms = v2.Compose(
            [
                v2.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
                v2.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
                v2.RandomRotation(degrees=10),  # Small rotation
            ]
        )

        dataset_with_torchvision = LeRobotDataset(repo_id=repo_id, image_transforms=torchvision_transforms)

        # Save examples with torchvision transforms
        if len(dataset_with_torchvision) > 0:
            sample = dataset_with_torchvision[0]
            print("Saving torchvision transform examples:")

            for cam_key in dataset_with_torchvision.meta.camera_keys:
                if cam_key in sample:
                    cam_name = cam_key.replace(".", "_").replace("/", "_")
                    save_image(sample[cam_key], f"{cam_name}_torchvision.png")

    except Exception as e:
        print(f"Could not load dataset '{repo_id}': {e}")


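# Presumably any callable that maps an image tensor to an image tensor could be passed
# as `image_transforms` in the same way, since the v2.Compose above is just such a
# callable (a hedged sketch, not verified against the full LeRobot API):
def flip_frames(img: torch.Tensor) -> torch.Tensor:
    # Hypothetical custom transform: horizontally flip every frame.
    return torch.flip(img, dims=[-1])

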
def main():
    """Run all examples"""
    print("LeRobot Dataset Image Transforms Examples")

    example_1_default_transforms()
    example_2_custom_transforms()
    example_3_torchvision_transforms()


if __name__ == "__main__":
    main()