forked from tangger/lerobot
Loads episode_data_index and stats during dataset __init__ (#85)
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
@@ -10,10 +10,13 @@ As an example, this script saves frames of episode number 5 of the PushT dataset
|
||||
This script supports several Hugging Face datasets, among which:
|
||||
1. [Pusht](https://huggingface.co/datasets/lerobot/pusht)
|
||||
2. [Xarm Lift Medium](https://huggingface.co/datasets/lerobot/xarm_lift_medium)
|
||||
3. [Aloha Sim Insertion Human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human)
|
||||
4. [Aloha Sim Insertion Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted)
|
||||
5. [Aloha Sim Transfer Cube Human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human)
|
||||
6. [Aloha Sim Transfer Cube Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted)
|
||||
3. [Xarm Lift Medium Replay](https://huggingface.co/datasets/lerobot/xarm_lift_medium_replay)
|
||||
4. [Xarm Push Medium](https://huggingface.co/datasets/lerobot/xarm_push_medium)
|
||||
5. [Xarm Push Medium Replay](https://huggingface.co/datasets/lerobot/xarm_push_medium_replay)
|
||||
6. [Aloha Sim Insertion Human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human)
|
||||
7. [Aloha Sim Insertion Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted)
|
||||
8. [Aloha Sim Transfer Cube Human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human)
|
||||
9. [Aloha Sim Transfer Cube Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted)
|
||||
|
||||
To try a different Hugging Face dataset, you can replace this line:
|
||||
```python
|
||||
@@ -22,12 +25,16 @@ hf_dataset, fps = load_dataset("lerobot/pusht", split="train"), 10
|
||||
by one of these:
|
||||
```python
|
||||
hf_dataset, fps = load_dataset("lerobot/xarm_lift_medium", split="train"), 15
|
||||
hf_dataset, fps = load_dataset("lerobot/xarm_lift_medium_replay", split="train"), 15
|
||||
hf_dataset, fps = load_dataset("lerobot/xarm_push_medium", split="train"), 15
|
||||
hf_dataset, fps = load_dataset("lerobot/xarm_push_medium_replay", split="train"), 15
|
||||
hf_dataset, fps = load_dataset("lerobot/aloha_sim_insertion_human", split="train"), 50
|
||||
hf_dataset, fps = load_dataset("lerobot/aloha_sim_insertion_scripted", split="train"), 50
|
||||
hf_dataset, fps = load_dataset("lerobot/aloha_sim_transfer_cube_human", split="train"), 50
|
||||
hf_dataset, fps = load_dataset("lerobot/aloha_sim_transfer_cube_scripted", split="train"), 50
|
||||
```
|
||||
"""
|
||||
# TODO(rcadene): remove this example file of using hf_dataset
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
@@ -37,19 +44,22 @@ from datasets import load_dataset
|
||||
# TODO(rcadene): list available datasets on lerobot page using `datasets`
|
||||
|
||||
# download/load hugging face dataset in pyarrow format
|
||||
hf_dataset, fps = load_dataset("lerobot/pusht", revision="v1.0", split="train"), 10
|
||||
hf_dataset, fps = load_dataset("lerobot/pusht", split="train"), 10
|
||||
|
||||
# display name of dataset and its features
|
||||
# TODO(rcadene): update to make the print pretty
|
||||
print(f"{hf_dataset=}")
|
||||
print(f"{hf_dataset.features=}")
|
||||
|
||||
# display useful statistics about frames and episodes, which are sequences of frames from the same video
|
||||
print(f"number of frames: {len(hf_dataset)=}")
|
||||
print(f"number of episodes: {len(hf_dataset.unique('episode_id'))=}")
|
||||
print(f"average number of frames per episode: {len(hf_dataset) / len(hf_dataset.unique('episode_id')):.3f}")
|
||||
print(f"number of episodes: {len(hf_dataset.unique('episode_index'))=}")
|
||||
print(
|
||||
f"average number of frames per episode: {len(hf_dataset) / len(hf_dataset.unique('episode_index')):.3f}"
|
||||
)
|
||||
|
||||
# select the frames belonging to episode number 5
|
||||
hf_dataset = hf_dataset.filter(lambda frame: frame["episode_id"] == 5)
|
||||
hf_dataset = hf_dataset.filter(lambda frame: frame["episode_index"] == 5)
|
||||
|
||||
# load all frames of episode 5 in RAM in PIL format
|
||||
frames = hf_dataset["observation.image"]
|
||||
|
||||
Reference in New Issue
Block a user