HF datasets works
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from datasets import load_dataset, load_from_disk
|
||||
|
||||
from lerobot.common.datasets.utils import load_previous_and_future_frames
|
||||
|
||||
@@ -25,15 +27,24 @@ class AlohaDataset(torch.utils.data.Dataset):
|
||||
self,
|
||||
dataset_id: str,
|
||||
version: str | None = "v1.0",
|
||||
root: Path | None = None,
|
||||
split: str = "train",
|
||||
transform: callable = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dataset_id = dataset_id
|
||||
self.version = version
|
||||
self.root = root
|
||||
self.split = split
|
||||
self.transform = transform
|
||||
self.delta_timestamps = delta_timestamps
|
||||
self.data_dict = load_dataset(f"lerobot/{self.dataset_id}", revision=self.version, split="train")
|
||||
if self.root is not None:
|
||||
self.data_dict = load_from_disk(Path(self.root) / self.dataset_id / self.split)
|
||||
else:
|
||||
self.data_dict = load_dataset(
|
||||
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
|
||||
)
|
||||
self.data_dict = self.data_dict.with_format("torch")
|
||||
|
||||
@property
|
||||
@@ -57,6 +68,15 @@ class AlohaDataset(torch.utils.data.Dataset):
|
||||
self.delta_timestamps,
|
||||
)
|
||||
|
||||
# convert images from channel last (PIL) to channel first (pytorch)
|
||||
for key in self.image_keys:
|
||||
if item[key].ndim == 3:
|
||||
item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w
|
||||
elif item[key].ndim == 4:
|
||||
item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w
|
||||
else:
|
||||
raise ValueError(item[key].ndim)
|
||||
|
||||
if self.transform is not None:
|
||||
item = self.transform(item)
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
@@ -7,12 +8,15 @@ from torchvision.transforms import v2
|
||||
from lerobot.common.datasets.utils import compute_stats
|
||||
from lerobot.common.transforms import NormalizeTransform, Prod
|
||||
|
||||
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
||||
|
||||
|
||||
def make_dataset(
|
||||
cfg,
|
||||
# set normalize=False to remove all transformations and keep images unnormalized in [0,255]
|
||||
normalize=True,
|
||||
stats_path=None,
|
||||
split="train",
|
||||
):
|
||||
if cfg.env.name == "xarm":
|
||||
from lerobot.common.datasets.xarm import XarmDataset
|
||||
@@ -57,6 +61,8 @@ def make_dataset(
|
||||
# instantiate a one frame dataset with light transform
|
||||
stats_dataset = clsfunc(
|
||||
dataset_id=cfg.dataset_id,
|
||||
split="train",
|
||||
root=DATA_DIR,
|
||||
transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
|
||||
)
|
||||
stats = compute_stats(stats_dataset)
|
||||
@@ -86,6 +92,8 @@ def make_dataset(
|
||||
|
||||
dataset = clsfunc(
|
||||
dataset_id=cfg.dataset_id,
|
||||
split=split,
|
||||
root=DATA_DIR,
|
||||
delta_timestamps=delta_timestamps,
|
||||
transform=transforms,
|
||||
)
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from datasets import load_dataset, load_from_disk
|
||||
|
||||
from lerobot.common.datasets.utils import load_previous_and_future_frames
|
||||
|
||||
@@ -23,18 +25,25 @@ class PushtDataset(torch.utils.data.Dataset):
|
||||
self,
|
||||
dataset_id: str,
|
||||
version: str | None = "v1.0",
|
||||
root: Path | None = None,
|
||||
split: str = "train",
|
||||
transform: callable = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dataset_id = dataset_id
|
||||
self.version = version
|
||||
self.root = root
|
||||
self.split = split
|
||||
self.transform = transform
|
||||
self.delta_timestamps = delta_timestamps
|
||||
# self.data_dict = load_dataset(f"lerobot/{self.dataset_id}", revision=self.version, split="train")
|
||||
self.data_dict = load_dataset(f"lerobot/{self.dataset_id}", split="train")
|
||||
if self.root is not None:
|
||||
self.data_dict = load_from_disk(Path(self.root) / self.dataset_id / self.split)
|
||||
else:
|
||||
self.data_dict = load_dataset(
|
||||
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
|
||||
)
|
||||
self.data_dict = self.data_dict.with_format("torch")
|
||||
self.data_dict.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
@@ -57,6 +66,15 @@ class PushtDataset(torch.utils.data.Dataset):
|
||||
self.delta_timestamps,
|
||||
)
|
||||
|
||||
# convert images from channel last (PIL) to channel first (pytorch)
|
||||
for key in self.image_keys:
|
||||
if item[key].ndim == 3:
|
||||
item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w
|
||||
elif item[key].ndim == 4:
|
||||
item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w
|
||||
else:
|
||||
raise ValueError(item[key].ndim)
|
||||
|
||||
if self.transform is not None:
|
||||
item = self.transform(item)
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from datasets import load_dataset, load_from_disk
|
||||
|
||||
from lerobot.common.datasets.utils import load_previous_and_future_frames
|
||||
|
||||
@@ -19,15 +21,24 @@ class XarmDataset(torch.utils.data.Dataset):
|
||||
self,
|
||||
dataset_id: str,
|
||||
version: str | None = "v1.0",
|
||||
root: Path | None = None,
|
||||
split: str = "train",
|
||||
transform: callable = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dataset_id = dataset_id
|
||||
self.version = version
|
||||
self.root = root
|
||||
self.split = split
|
||||
self.transform = transform
|
||||
self.delta_timestamps = delta_timestamps
|
||||
self.data_dict = load_dataset(f"lerobot/{self.dataset_id}", revision=self.version, split="train")
|
||||
if self.root is not None:
|
||||
self.data_dict = load_from_disk(Path(self.root) / self.dataset_id / self.split)
|
||||
else:
|
||||
self.data_dict = load_dataset(
|
||||
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
|
||||
)
|
||||
self.data_dict = self.data_dict.with_format("torch")
|
||||
|
||||
@property
|
||||
@@ -51,6 +62,15 @@ class XarmDataset(torch.utils.data.Dataset):
|
||||
self.delta_timestamps,
|
||||
)
|
||||
|
||||
# convert images from channel last (PIL) to channel first (pytorch)
|
||||
for key in self.image_keys:
|
||||
if item[key].ndim == 3:
|
||||
item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w
|
||||
elif item[key].ndim == 4:
|
||||
item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w
|
||||
else:
|
||||
raise ValueError(item[key].ndim)
|
||||
|
||||
if self.transform is not None:
|
||||
item = self.transform(item)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user