Add MultiLerobotDataset for training with multiple LeRobotDatasets (#229)

This commit is contained in:
Alexander Soare
2024-05-30 16:12:21 +01:00
committed by GitHub
parent 265b0ec44d
commit 111cd58f8a
8 changed files with 352 additions and 72 deletions

View File

@@ -16,9 +16,9 @@
import logging
import torch
from omegaconf import OmegaConf
from omegaconf import ListConfig, OmegaConf
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
def resolve_delta_timestamps(cfg):
@@ -35,11 +35,27 @@ def resolve_delta_timestamps(cfg):
cfg.training.delta_timestamps[key] = eval(delta_timestamps[key])
def make_dataset(
cfg,
split="train",
):
if cfg.env.name not in cfg.dataset_repo_id:
def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotDataset:
"""
Args:
cfg: A Hydra config as per the LeRobot config scheme.
split: Select the data subset used to create an instance of LeRobotDataset.
All datasets hosted on [lerobot](https://huggingface.co/lerobot) contain only one subset: "train".
Thus, by default, `split="train"` selects all the available data. `split` aims to work like the
slicer in the hugging face datasets:
https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits
As of now, it only supports `split="train[:n]"` to load the first n frames of the dataset or
`split="train[n:]"` to load the last n frames. For instance `split="train[:1000]"`.
Returns:
The LeRobotDataset.
"""
if not isinstance(cfg.dataset_repo_id, (str, ListConfig)):
raise ValueError(
"Expected cfg.dataset_repo_id to be either a single string to load one dataset or a list of "
"strings to load multiple datasets."
)
if isinstance(cfg.dataset_repo_id, str) and cfg.env.name not in cfg.dataset_repo_id:
logging.warning(
f"There might be a mismatch between your training dataset ({cfg.dataset_repo_id=}) and your "
f"environment ({cfg.env.name=})."
@@ -49,11 +65,16 @@ def make_dataset(
# TODO(rcadene): add data augmentations
dataset = LeRobotDataset(
cfg.dataset_repo_id,
split=split,
delta_timestamps=cfg.training.get("delta_timestamps"),
)
if isinstance(cfg.dataset_repo_id, str):
dataset = LeRobotDataset(
cfg.dataset_repo_id,
split=split,
delta_timestamps=cfg.training.get("delta_timestamps"),
)
else:
dataset = MultiLeRobotDataset(
cfg.dataset_repo_id, split=split, delta_timestamps=cfg.training.get("delta_timestamps")
)
if cfg.get("override_dataset_stats"):
for key, stats_dict in cfg.override_dataset_stats.items():