Refactor datasets into LeRobotDataset (#91)
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
@@ -1,9 +1,12 @@
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
||||
|
||||
|
||||
@@ -11,22 +14,10 @@ def make_dataset(
|
||||
cfg,
|
||||
split="train",
|
||||
):
|
||||
if cfg.env.name == "xarm":
|
||||
from lerobot.common.datasets.xarm import XarmDataset
|
||||
|
||||
clsfunc = XarmDataset
|
||||
|
||||
elif cfg.env.name == "pusht":
|
||||
from lerobot.common.datasets.pusht import PushtDataset
|
||||
|
||||
clsfunc = PushtDataset
|
||||
|
||||
elif cfg.env.name == "aloha":
|
||||
from lerobot.common.datasets.aloha import AlohaDataset
|
||||
|
||||
clsfunc = AlohaDataset
|
||||
else:
|
||||
raise ValueError(cfg.env.name)
|
||||
if cfg.env.name not in cfg.dataset.repo_id:
|
||||
logging.warning(
|
||||
f"There might be a mismatch between your training dataset ({cfg.dataset.repo_id=}) and your environment ({cfg.env.name=})."
|
||||
)
|
||||
|
||||
delta_timestamps = cfg.policy.get("delta_timestamps")
|
||||
if delta_timestamps is not None:
|
||||
@@ -36,8 +27,8 @@ def make_dataset(
|
||||
|
||||
# TODO(rcadene): add data augmentations
|
||||
|
||||
dataset = clsfunc(
|
||||
dataset_id=cfg.dataset_id,
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
split=split,
|
||||
root=DATA_DIR,
|
||||
delta_timestamps=delta_timestamps,
|
||||
|
||||
Reference in New Issue
Block a user