HF datasets works
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
@@ -7,12 +8,15 @@ from torchvision.transforms import v2
|
||||
from lerobot.common.datasets.utils import compute_stats
|
||||
from lerobot.common.transforms import NormalizeTransform, Prod
|
||||
|
||||
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
||||
|
||||
|
||||
def make_dataset(
|
||||
cfg,
|
||||
# set normalize=False to remove all transformations and keep images unnormalized in [0,255]
|
||||
normalize=True,
|
||||
stats_path=None,
|
||||
split="train",
|
||||
):
|
||||
if cfg.env.name == "xarm":
|
||||
from lerobot.common.datasets.xarm import XarmDataset
|
||||
@@ -57,6 +61,8 @@ def make_dataset(
|
||||
# instantiate a one frame dataset with light transform
|
||||
stats_dataset = clsfunc(
|
||||
dataset_id=cfg.dataset_id,
|
||||
split="train",
|
||||
root=DATA_DIR,
|
||||
transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
|
||||
)
|
||||
stats = compute_stats(stats_dataset)
|
||||
@@ -86,6 +92,8 @@ def make_dataset(
|
||||
|
||||
dataset = clsfunc(
|
||||
dataset_id=cfg.dataset_id,
|
||||
split=split,
|
||||
root=DATA_DIR,
|
||||
delta_timestamps=delta_timestamps,
|
||||
transform=transforms,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user