Refactor datasets into LeRobotDataset (#91)

Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
Remi
2024-04-25 12:23:12 +02:00
committed by GitHub
parent e760e4cd63
commit 659c69a1c0
90 changed files with 167 additions and 352 deletions

View File

@@ -1,9 +1,12 @@
import logging
import os
from pathlib import Path
import torch
from omegaconf import OmegaConf
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
@@ -11,22 +14,10 @@ def make_dataset(
cfg,
split="train",
):
if cfg.env.name == "xarm":
from lerobot.common.datasets.xarm import XarmDataset
clsfunc = XarmDataset
elif cfg.env.name == "pusht":
from lerobot.common.datasets.pusht import PushtDataset
clsfunc = PushtDataset
elif cfg.env.name == "aloha":
from lerobot.common.datasets.aloha import AlohaDataset
clsfunc = AlohaDataset
else:
raise ValueError(cfg.env.name)
if cfg.env.name not in cfg.dataset.repo_id:
logging.warning(
f"There might be a mismatch between your training dataset ({cfg.dataset.repo_id=}) and your environment ({cfg.env.name=})."
)
delta_timestamps = cfg.policy.get("delta_timestamps")
if delta_timestamps is not None:
@@ -36,8 +27,8 @@ def make_dataset(
# TODO(rcadene): add data augmentations
dataset = clsfunc(
dataset_id=cfg.dataset_id,
dataset = LeRobotDataset(
cfg.dataset.repo_id,
split=split,
root=DATA_DIR,
delta_timestamps=delta_timestamps,