Improve dataset examples (#82)

Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
Remi
2024-04-18 11:43:16 +02:00
committed by GitHub
parent d5c4b0c344
commit 0928afd37d
15 changed files with 274 additions and 165 deletions

View File

@@ -0,0 +1,59 @@
"""
This script demonstrates the visualization of various robotic datasets from Hugging Face hub.
It covers the steps from loading the datasets, filtering specific episodes, and converting the frame data to MP4 videos.
Importantly, the dataset format is agnostic to any deep learning library and doesn't require using `lerobot` functions.
It is compatible with pytorch, jax, numpy, etc.
As an example, this script saves frames of episode number 5 of the PushT dataset to a mp4 video and saves the result here:
`outputs/examples/1_visualize_hugging_face_datasets/episode_5.mp4`
This script supports several Hugging Face datasets, among which:
1. [Pusht](https://huggingface.co/datasets/lerobot/pusht)
2. [Xarm Lift Medium](https://huggingface.co/datasets/lerobot/xarm_lift_medium)
3. [Aloha Sim Insertion Human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human)
4. [Aloha Sim Insertion Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted)
5. [Aloha Sim Transfer Cube Human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human)
6. [Aloha Sim Transfer Cube Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted)
To try a different Hugging Face dataset, you can replace this line:
```python
hf_dataset, fps = load_dataset("lerobot/pusht", split="train"), 10
```
by one of these:
```python
hf_dataset, fps = load_dataset("lerobot/xarm_lift_medium", split="train"), 15
hf_dataset, fps = load_dataset("lerobot/aloha_sim_insertion_human", split="train"), 50
hf_dataset, fps = load_dataset("lerobot/aloha_sim_insertion_scripted", split="train"), 50
hf_dataset, fps = load_dataset("lerobot/aloha_sim_transfer_cube_human", split="train"), 50
hf_dataset, fps = load_dataset("lerobot/aloha_sim_transfer_cube_scripted", split="train"), 50
```
"""
from pathlib import Path
import imageio
from datasets import load_dataset
# TODO(rcadene): list available datasets on lerobot page using `datasets`
# download/load hugging face dataset in pyarrow format
hf_dataset, fps = load_dataset("lerobot/pusht", split="train"), 10
# display name of dataset and its features
print(f"{hf_dataset=}")
print(f"{hf_dataset.features=}")
# display useful statistics about frames and episodes, which are sequences of frames from the same video
print(f"number of frames: {len(hf_dataset)=}")
print(f"number of episodes: {len(hf_dataset.unique('episode_id'))=}")
print(f"average number of frames per episode: {len(hf_dataset) / len(hf_dataset.unique('episode_id')):.3f}")
# select the frames belonging to episode number 5
hf_dataset = hf_dataset.filter(lambda frame: frame["episode_id"] == 5)
# load all frames of episode 5 in RAM in PIL format
frames = hf_dataset["observation.image"]
# save episode frames to a mp4 video
Path("outputs/examples/1_load_hugging_face_dataset").mkdir(parents=True, exist_ok=True)
imageio.mimsave("outputs/examples/1_load_hugging_face_dataset/episode_5.mp4", frames, fps=fps)

View File

@@ -1,20 +0,0 @@
import os
from pathlib import Path
import lerobot
from lerobot.common.datasets.pusht import PushtDataset
from lerobot.scripts.visualize_dataset import render_dataset
print(lerobot.available_datasets)
# >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted', 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted', 'pusht', 'xarm_lift_medium']
# TODO(rcadene): remove DATA_DIR
dataset = PushtDataset("pusht", root=Path(os.environ.get("DATA_DIR")))
video_paths = render_dataset(
dataset,
out_dir="outputs/visualize_dataset/example",
max_num_episodes=1,
)
print(video_paths)
# ['outputs/visualize_dataset/example/episode_0.mp4']

View File

@@ -0,0 +1,98 @@
"""
This script demonstrates the use of the PushtDataset class for handling and processing robotic datasets from Hugging Face.
It illustrates how to load datasets, manipulate them, and apply transformations suitable for machine learning tasks in PyTorch.
Features included in this script:
- Loading a dataset and accessing its properties.
- Filtering data by episode number.
- Converting tensor data for visualization.
- Saving video files from dataset frames.
- Using advanced dataset features like timestamp-based frame selection.
- Demonstrating compatibility with PyTorch DataLoader for batch processing.
The script ends with examples of how to batch process data using PyTorch's DataLoader.
To try a different Hugging Face dataset, you can replace:
```python
dataset = PushtDataset()
```
by one of these:
```python
dataset = XarmDataset()
dataset = AlohaDataset("aloha_sim_insertion_human")
dataset = AlohaDataset("aloha_sim_insertion_scripted")
dataset = AlohaDataset("aloha_sim_transfer_cube_human")
dataset = AlohaDataset("aloha_sim_transfer_cube_scripted")
```
"""
from pathlib import Path
import imageio
import torch
from lerobot.common.datasets.pusht import PushtDataset
# TODO(rcadene): List available datasets and their dataset ids (e.g. PushtDataset, AlohaDataset(dataset_id="aloha_sim_insertion_human"))
# print("List of available datasets", lerobot.available_datasets)
# # >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted',
# # 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted',
# # 'pusht', 'xarm_lift_medium']
# You can easily load datasets from LeRobot
dataset = PushtDataset()
# All LeRobot datasets are actually a thin wrapper around an underlying Hugging Face dataset (see https://huggingface.co/docs/datasets/index for more information).
print(f"{dataset=}")
print(f"{dataset.hf_dataset=}")
# and provide additional utilities for robotics and compatibility with pytorch
print(f"number of samples/frames: {dataset.num_samples=}")
print(f"number of episodes: {dataset.num_episodes=}")
print(f"average number of frames per episode: {dataset.num_samples / dataset.num_episodes:.3f}")
print(f"frames per second used during data collection: {dataset.fps=}")
print(f"keys to access images from cameras: {dataset.image_keys=}")
# While the LeRobot dataset adds helpers for working within our library, we still expose the underling Hugging Face dataset. It may be freely replaced or modified in place. Here we use the filtering to keep only frames from episode 5.
dataset.hf_dataset = dataset.hf_dataset.filter(lambda frame: frame["episode_id"] == 5)
# LeRobot datsets actually subclass PyTorch datasets. So you can do everything you know and love from working with the latter, for example: iterating through the dataset. Here we grap all the image frames.
frames = [sample["observation.image"] for sample in dataset]
# but frames are now channel first to follow pytorch convention,
# to view them, we convert to channel last
frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
# and finally save them to a mp4 video
Path("outputs/examples/2_load_lerobot_dataset").mkdir(parents=True, exist_ok=True)
imageio.mimsave("outputs/examples/2_load_lerobot_dataset/episode_5.mp4", frames, fps=dataset.fps)
# For many machine learning applications we need to load histories of past observations, or trajectorys of future actions. Our datasets can load previous and future frames for each key/modality,
# using timestamps differences with the current loaded frame. For instance:
delta_timestamps = {
# loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
"observation.image": [-1, -0.5, -0.20, 0],
# loads 8 state vectors: 1.5 seconds before, 1 second before, ... 20 ms, 10 ms, and current frame
"observation.state": [-1.5, -1, -0.5, -0.20, -0.10, -0.02, -0.01, 0],
# loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
"action": [t / dataset.fps for t in range(64)],
}
dataset = PushtDataset(delta_timestamps=delta_timestamps)
print(f"{dataset[0]['observation.image'].shape=}") # (4,c,h,w)
print(f"{dataset[0]['observation.state'].shape=}") # (8,c)
print(f"{dataset[0]['action'].shape=}") # (64,c)
# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers
# because they are just PyTorch datasets.
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=4,
batch_size=32,
shuffle=True,
)
for batch in dataloader:
print(f"{batch['observation.image'].shape=}") # (32,4,c,h,w)
print(f"{batch['observation.state'].shape=}") # (32,8,c)
print(f"{batch['action'].shape=}") # (32,64,c)
break