Refactor datasets into LeRobotDataset (#91)

Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
Remi
2024-04-25 12:23:12 +02:00
committed by GitHub
parent e760e4cd63
commit 659c69a1c0
90 changed files with 167 additions and 352 deletions

View File

@@ -1,78 +0,0 @@
from pathlib import Path
import torch
from lerobot.common.datasets.utils import (
load_episode_data_index,
load_hf_dataset,
load_previous_and_future_frames,
load_stats,
)
class AlohaDataset(torch.utils.data.Dataset):
"""
https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human
https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted
https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human
https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted
"""
# Copied from lerobot/__init__.py
available_datasets = [
"aloha_sim_insertion_human",
"aloha_sim_insertion_scripted",
"aloha_sim_transfer_cube_human",
"aloha_sim_transfer_cube_scripted",
]
fps = 50
image_keys = ["observation.images.top"]
def __init__(
self,
dataset_id: str,
version: str | None = "v1.1",
root: Path | None = None,
split: str = "train",
transform: callable = None,
delta_timestamps: dict[list[float]] | None = None,
):
super().__init__()
self.dataset_id = dataset_id
self.version = version
self.root = root
self.split = split
self.transform = transform
self.delta_timestamps = delta_timestamps
# load data from hub or locally when root is provided
self.hf_dataset = load_hf_dataset(dataset_id, version, root, split)
self.episode_data_index = load_episode_data_index(dataset_id, version, root)
self.stats = load_stats(dataset_id, version, root)
@property
def num_samples(self) -> int:
return len(self.hf_dataset)
@property
def num_episodes(self) -> int:
return len(self.hf_dataset.unique("episode_index"))
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
item = self.hf_dataset[idx]
if self.delta_timestamps is not None:
item = load_previous_and_future_frames(
item,
self.hf_dataset,
self.episode_data_index,
self.delta_timestamps,
tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
)
if self.transform is not None:
item = self.transform(item)
return item

View File

@@ -1,9 +1,12 @@
import logging
import os
from pathlib import Path
import torch
from omegaconf import OmegaConf
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
# Optional local data directory; `None` means "fetch datasets from the hub".
_data_dir = os.environ.get("DATA_DIR")
DATA_DIR = Path(_data_dir) if _data_dir is not None else None
@@ -11,22 +14,10 @@ def make_dataset(
cfg,
split="train",
):
if cfg.env.name == "xarm":
from lerobot.common.datasets.xarm import XarmDataset
clsfunc = XarmDataset
elif cfg.env.name == "pusht":
from lerobot.common.datasets.pusht import PushtDataset
clsfunc = PushtDataset
elif cfg.env.name == "aloha":
from lerobot.common.datasets.aloha import AlohaDataset
clsfunc = AlohaDataset
else:
raise ValueError(cfg.env.name)
if cfg.env.name not in cfg.dataset.repo_id:
logging.warning(
f"There might be a mismatch between your training dataset ({cfg.dataset.repo_id=}) and your environment ({cfg.env.name=})."
)
delta_timestamps = cfg.policy.get("delta_timestamps")
if delta_timestamps is not None:
@@ -36,8 +27,8 @@ def make_dataset(
# TODO(rcadene): add data augmentations
dataset = clsfunc(
dataset_id=cfg.dataset_id,
dataset = LeRobotDataset(
cfg.dataset.repo_id,
split=split,
root=DATA_DIR,
delta_timestamps=delta_timestamps,

View File

@@ -1,36 +1,21 @@
from pathlib import Path
import datasets
import torch
from lerobot.common.datasets.utils import (
load_episode_data_index,
load_hf_dataset,
load_info,
load_previous_and_future_frames,
load_stats,
)
class XarmDataset(torch.utils.data.Dataset):
"""
https://huggingface.co/datasets/lerobot/xarm_lift_medium
https://huggingface.co/datasets/lerobot/xarm_lift_medium_replay
https://huggingface.co/datasets/lerobot/xarm_push_medium
https://huggingface.co/datasets/lerobot/xarm_push_medium_replay
"""
# Copied from lerobot/__init__.py
available_datasets = [
"xarm_lift_medium",
"xarm_lift_medium_replay",
"xarm_push_medium",
"xarm_push_medium_replay",
]
fps = 15
image_keys = ["observation.image"]
class LeRobotDataset(torch.utils.data.Dataset):
def __init__(
self,
dataset_id: str,
repo_id: str,
version: str | None = "v1.1",
root: Path | None = None,
split: str = "train",
@@ -38,16 +23,25 @@ class XarmDataset(torch.utils.data.Dataset):
delta_timestamps: dict[list[float]] | None = None,
):
super().__init__()
self.dataset_id = dataset_id
self.repo_id = repo_id
self.version = version
self.root = root
self.split = split
self.transform = transform
self.delta_timestamps = delta_timestamps
# load data from hub or locally when root is provided
self.hf_dataset = load_hf_dataset(dataset_id, version, root, split)
self.episode_data_index = load_episode_data_index(dataset_id, version, root)
self.stats = load_stats(dataset_id, version, root)
self.hf_dataset = load_hf_dataset(repo_id, version, root, split)
self.episode_data_index = load_episode_data_index(repo_id, version, root)
self.stats = load_stats(repo_id, version, root)
self.info = load_info(repo_id, version, root)
@property
def fps(self) -> int:
return self.info["fps"]
@property
def image_keys(self) -> list[str]:
return [key for key, feats in self.hf_dataset.features.items() if isinstance(feats, datasets.Image)]
@property
def num_samples(self) -> int:

View File

@@ -1,76 +0,0 @@
from pathlib import Path
import torch
from lerobot.common.datasets.utils import (
load_episode_data_index,
load_hf_dataset,
load_previous_and_future_frames,
load_stats,
)
class PushtDataset(torch.utils.data.Dataset):
"""
https://huggingface.co/datasets/lerobot/pusht
Arguments
----------
delta_timestamps : dict[list[float]] | None, optional
Loads data from frames with a shift in timestamps with a different strategy for each data key (e.g. state, action or image)
If `None`, no shift is applied to current timestamp and the data from the current frame is loaded.
"""
# Copied from lerobot/__init__.py
available_datasets = ["pusht"]
fps = 10
image_keys = ["observation.image"]
def __init__(
self,
dataset_id: str = "pusht",
version: str | None = "v1.1",
root: Path | None = None,
split: str = "train",
transform: callable = None,
delta_timestamps: dict[list[float]] | None = None,
):
super().__init__()
self.dataset_id = dataset_id
self.version = version
self.root = root
self.split = split
self.transform = transform
self.delta_timestamps = delta_timestamps
# load data from hub or locally when root is provided
self.hf_dataset = load_hf_dataset(dataset_id, version, root, split)
self.episode_data_index = load_episode_data_index(dataset_id, version, root)
self.stats = load_stats(dataset_id, version, root)
@property
def num_samples(self) -> int:
return len(self.hf_dataset)
@property
def num_episodes(self) -> int:
return len(self.episode_data_index["from"])
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
item = self.hf_dataset[idx]
if self.delta_timestamps is not None:
item = load_previous_and_future_frames(
item,
self.hf_dataset,
self.episode_data_index,
self.delta_timestamps,
tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
)
if self.transform is not None:
item = self.transform(item)
return item

View File

@@ -1,3 +1,4 @@
import json
from copy import deepcopy
from math import ceil
from pathlib import Path
@@ -15,7 +16,7 @@ from torchvision import transforms
def flatten_dict(d, parent_key="", sep="/"):
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
For example:
```
>>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}`
@@ -61,19 +62,17 @@ def hf_transform_to_torch(items_dict):
return items_dict
def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset:
    """hf_dataset contains all the observations, states, actions, rewards, etc.

    Loads from disk when a local `root` directory is provided, otherwise
    downloads `repo_id` from the Hugging Face hub at the pinned `version`
    revision.
    """
    # NOTE(review): the diff rendering left both the pre- and post-rename
    # signatures in place (the old branch referenced a now-undefined
    # `dataset_id`); this is the coherent post-refactor version.
    if root is not None:
        hf_dataset = load_from_disk(str(Path(root) / repo_id / split))
    else:
        hf_dataset = load_dataset(repo_id, revision=version, split=split)
    # Convert items to torch tensors lazily, on access.
    hf_dataset.set_transform(hf_transform_to_torch)
    return hf_dataset
def load_episode_data_index(repo_id, version, root) -> dict[str, torch.Tensor]:
    """episode_data_index contains the range of indices for each episode.

    The returned mapping has two tensors, "from" and "to", giving the first
    (inclusive) and last (exclusive) frame index of every episode.
    """
    # NOTE(review): reconstructed from a diff that contained both the pre- and
    # post-rename versions of this function; this is the post-refactor form.
    if root is not None:
        path = Path(root) / repo_id / "meta_data" / "episode_data_index.safetensors"
    else:
        path = hf_hub_download(
            repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=version
        )
    return load_file(path)
def load_stats(repo_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
    """stats contains the statistics per modality computed over the full dataset, such as max, min, mean, std.

    Example:
    ```python
    normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"]
    ```
    """
    # NOTE(review): reconstructed from a diff that contained both the pre- and
    # post-rename versions of this function; this is the post-refactor form.
    if root is not None:
        path = Path(root) / repo_id / "meta_data" / "stats.safetensors"
    else:
        path = hf_hub_download(repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=version)
    stats = load_file(path)
    # Stored flat with "/"-separated keys; rebuild the nested dict structure.
    return unflatten_dict(stats)
def load_info(repo_id, version, root) -> dict:
    """Return the dataset's auxiliary metadata that is not stored elsewhere.

    The metadata is read from ``meta_data/info.json``, taken from the local
    ``root`` directory when one is provided, otherwise downloaded from the hub
    at the pinned ``version`` revision.

    Example:
    ```python
    print("frame per second used to collect the video", info["fps"])
    ```
    """
    if root is None:
        # No local copy: fetch the JSON file from the Hugging Face hub.
        path = hf_hub_download(repo_id, "meta_data/info.json", repo_type="dataset", revision=version)
    else:
        path = Path(root) / repo_id / "meta_data" / "info.json"
    with open(path) as f:
        return json.load(f)
def load_previous_and_future_frames(
item: dict[str, torch.Tensor],
hf_dataset: datasets.Dataset,