forked from tangger/lerobot
Refactor datasets into LeRobotDataset (#91)
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
This script demonstrates the use of the PushtDataset class for handling and processing robotic datasets from Hugging Face.
|
||||
This script demonstrates the use of `LeRobotDataset` class for handling and processing robotic datasets from Hugging Face.
|
||||
It illustrates how to load datasets, manipulate them, and apply transformations suitable for machine learning tasks in PyTorch.
|
||||
|
||||
Features included in this script:
|
||||
@@ -11,22 +11,6 @@ Features included in this script:
|
||||
- Demonstrating compatibility with PyTorch DataLoader for batch processing.
|
||||
|
||||
The script ends with examples of how to batch process data using PyTorch's DataLoader.
|
||||
|
||||
To try a different Hugging Face dataset, you can replace:
|
||||
```python
|
||||
dataset = PushtDataset()
|
||||
```
|
||||
by one of these:
|
||||
```python
|
||||
dataset = XarmDataset("xarm_lift_medium")
|
||||
dataset = XarmDataset("xarm_lift_medium_replay")
|
||||
dataset = XarmDataset("xarm_push_medium")
|
||||
dataset = XarmDataset("xarm_push_medium_replay")
|
||||
dataset = AlohaDataset("aloha_sim_insertion_human")
|
||||
dataset = AlohaDataset("aloha_sim_insertion_scripted")
|
||||
dataset = AlohaDataset("aloha_sim_transfer_cube_human")
|
||||
dataset = AlohaDataset("aloha_sim_transfer_cube_scripted")
|
||||
```
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
@@ -34,31 +18,33 @@ from pathlib import Path
|
||||
import imageio
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.pusht import PushtDataset
|
||||
import lerobot
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
# TODO(rcadene): List available datasets and their dataset ids (e.g. PushtDataset, AlohaDataset(dataset_id="aloha_sim_insertion_human"))
|
||||
# print("List of available datasets", lerobot.available_datasets)
|
||||
# # >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted',
|
||||
# # 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted',
|
||||
# # 'pusht', 'xarm_lift_medium']
|
||||
print("List of available datasets", lerobot.available_datasets)
|
||||
# # >>> ['lerobot/aloha_sim_insertion_human', 'lerobot/aloha_sim_insertion_scripted',
|
||||
# # 'lerobot/aloha_sim_transfer_cube_human', 'lerobot/aloha_sim_transfer_cube_scripted',
|
||||
# # 'lerobot/pusht', 'lerobot/xarm_lift_medium']
|
||||
|
||||
repo_id = "lerobot/pusht"
|
||||
|
||||
# You can easily load datasets from LeRobot
|
||||
dataset = PushtDataset()
|
||||
# You can easily load a dataset from a Hugging Face repositery
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
|
||||
# All LeRobot datasets are actually a thin wrapper around an underlying Hugging Face dataset (see https://huggingface.co/docs/datasets/index for more information).
|
||||
# LeRobotDataset is actually a thin wrapper around an underlying Hugging Face dataset (see https://huggingface.co/docs/datasets/index for more information).
|
||||
# TODO(rcadene): update to make the print pretty
|
||||
print(f"{dataset=}")
|
||||
print(f"{dataset.hf_dataset=}")
|
||||
|
||||
# and provide additional utilities for robotics and compatibility with pytorch
|
||||
# and provides additional utilities for robotics and compatibility with pytorch
|
||||
print(f"number of samples/frames: {dataset.num_samples=}")
|
||||
print(f"number of episodes: {dataset.num_episodes=}")
|
||||
print(f"average number of frames per episode: {dataset.num_samples / dataset.num_episodes:.3f}")
|
||||
print(f"frames per second used during data collection: {dataset.fps=}")
|
||||
print(f"keys to access images from cameras: {dataset.image_keys=}")
|
||||
|
||||
# While the LeRobot dataset adds helpers for working within our library, we still expose the underling Hugging Face dataset. It may be freely replaced or modified in place. Here we use the filtering to keep only frames from episode 5.
|
||||
# While the LeRobotDataset adds helpers for working within our library, we still expose the underling Hugging Face dataset.
|
||||
# It may be freely replaced or modified in place. Here we use the filtering to keep only frames from episode 5.
|
||||
# TODO(rcadene): remove this example of accessing hf_dataset
|
||||
dataset.hf_dataset = dataset.hf_dataset.filter(lambda frame: frame["episode_index"] == 5)
|
||||
|
||||
@@ -85,7 +71,7 @@ delta_timestamps = {
|
||||
# loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
|
||||
"action": [t / dataset.fps for t in range(64)],
|
||||
}
|
||||
dataset = PushtDataset(delta_timestamps=delta_timestamps)
|
||||
dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
|
||||
print(f"{dataset[0]['observation.image'].shape=}") # (4,c,h,w)
|
||||
print(f"{dataset[0]['observation.state'].shape=}") # (8,c)
|
||||
print(f"{dataset[0]['action'].shape=}") # (64,c)
|
||||
|
||||
Reference in New Issue
Block a user