Refactor datasets into LeRobotDataset (#91)

Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
Remi
2024-04-25 12:23:12 +02:00
committed by GitHub
parent e760e4cd63
commit 659c69a1c0
90 changed files with 167 additions and 352 deletions

View File

@@ -1,3 +1,4 @@
import json
from copy import deepcopy
from math import ceil
from pathlib import Path
@@ -15,7 +16,7 @@ from torchvision import transforms
def flatten_dict(d, parent_key="", sep="/"):
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
For example:
```
>>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}`
@@ -61,19 +62,17 @@ def hf_transform_to_torch(items_dict):
return items_dict
def load_hf_dataset(dataset_id, version, root, split) -> datasets.Dataset:
def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
if root is not None:
hf_dataset = load_from_disk(str(Path(root) / dataset_id / split))
hf_dataset = load_from_disk(str(Path(root) / repo_id / split))
else:
# TODO(rcadene): remove dataset_id everywhere and use repo_id instead
repo_id = f"lerobot/{dataset_id}"
hf_dataset = load_dataset(repo_id, revision=version, split=split)
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
def load_episode_data_index(dataset_id, version, root) -> dict[str, torch.Tensor]:
def load_episode_data_index(repo_id, version, root) -> dict[str, torch.Tensor]:
"""episode_data_index contains the range of indices for each episode
Example:
@@ -84,9 +83,8 @@ def load_episode_data_index(dataset_id, version, root) -> dict[str, torch.Tensor
```
"""
if root is not None:
path = Path(root) / dataset_id / "meta_data" / "episode_data_index.safetensors"
path = Path(root) / repo_id / "meta_data" / "episode_data_index.safetensors"
else:
repo_id = f"lerobot/{dataset_id}"
path = hf_hub_download(
repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=version
)
@@ -94,7 +92,7 @@ def load_episode_data_index(dataset_id, version, root) -> dict[str, torch.Tensor
return load_file(path)
def load_stats(dataset_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
def load_stats(repo_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
"""stats contains the statistics per modality computed over the full dataset, such as max, min, mean, std
Example:
@@ -103,15 +101,32 @@ def load_stats(dataset_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
```
"""
if root is not None:
path = Path(root) / dataset_id / "meta_data" / "stats.safetensors"
path = Path(root) / repo_id / "meta_data" / "stats.safetensors"
else:
repo_id = f"lerobot/{dataset_id}"
path = hf_hub_download(repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=version)
stats = load_file(path)
return unflatten_dict(stats)
def load_info(repo_id, version, root) -> dict:
"""info contains useful information regarding the dataset that are not stored elsewhere
Example:
```python
print("frame per second used to collect the video", info["fps"])
```
"""
if root is not None:
path = Path(root) / repo_id / "meta_data" / "info.json"
else:
path = hf_hub_download(repo_id, "meta_data/info.json", repo_type="dataset", revision=version)
with open(path) as f:
info = json.load(f)
return info
def load_previous_and_future_frames(
item: dict[str, torch.Tensor],
hf_dataset: datasets.Dataset,