HF datasets integration works

This commit is contained in:
Cadene
2024-04-16 12:20:38 +00:00
parent 5edd9a89a0
commit 0980fff6cc
42 changed files with 630 additions and 87 deletions

View File

@@ -1,5 +1,7 @@
from pathlib import Path
import torch
from datasets import load_dataset
from datasets import load_dataset, load_from_disk
from lerobot.common.datasets.utils import load_previous_and_future_frames
@@ -25,15 +27,24 @@ class AlohaDataset(torch.utils.data.Dataset):
self,
dataset_id: str,
version: str | None = "v1.0",
root: Path | None = None,
split: str = "train",
transform: callable = None,
delta_timestamps: dict[list[float]] | None = None,
):
super().__init__()
self.dataset_id = dataset_id
self.version = version
self.root = root
self.split = split
self.transform = transform
self.delta_timestamps = delta_timestamps
self.data_dict = load_dataset(f"lerobot/{self.dataset_id}", revision=self.version, split="train")
if self.root is not None:
self.data_dict = load_from_disk(Path(self.root) / self.dataset_id / self.split)
else:
self.data_dict = load_dataset(
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
)
self.data_dict = self.data_dict.with_format("torch")
@property
@@ -57,6 +68,15 @@ class AlohaDataset(torch.utils.data.Dataset):
self.delta_timestamps,
)
# convert images from channel last (PIL) to channel first (pytorch)
for key in self.image_keys:
if item[key].ndim == 3:
item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w
elif item[key].ndim == 4:
item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w
else:
raise ValueError(item[key].ndim)
if self.transform is not None:
item = self.transform(item)

View File

@@ -1,4 +1,5 @@
import logging
import os
from pathlib import Path
import torch
@@ -7,12 +8,15 @@ from torchvision.transforms import v2
from lerobot.common.datasets.utils import compute_stats
from lerobot.common.transforms import NormalizeTransform, Prod
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
def make_dataset(
cfg,
# set normalize=False to remove all transformations and keep images unnormalized in [0,255]
normalize=True,
stats_path=None,
split="train",
):
if cfg.env.name == "xarm":
from lerobot.common.datasets.xarm import XarmDataset
@@ -57,6 +61,8 @@ def make_dataset(
# instantiate a one frame dataset with light transform
stats_dataset = clsfunc(
dataset_id=cfg.dataset_id,
split="train",
root=DATA_DIR,
transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
)
stats = compute_stats(stats_dataset)
@@ -86,6 +92,8 @@ def make_dataset(
dataset = clsfunc(
dataset_id=cfg.dataset_id,
split=split,
root=DATA_DIR,
delta_timestamps=delta_timestamps,
transform=transforms,
)

View File

@@ -1,5 +1,7 @@
from pathlib import Path
import torch
from datasets import load_dataset
from datasets import load_dataset, load_from_disk
from lerobot.common.datasets.utils import load_previous_and_future_frames
@@ -23,18 +25,25 @@ class PushtDataset(torch.utils.data.Dataset):
self,
dataset_id: str,
version: str | None = "v1.0",
root: Path | None = None,
split: str = "train",
transform: callable = None,
delta_timestamps: dict[list[float]] | None = None,
):
super().__init__()
self.dataset_id = dataset_id
self.version = version
self.root = root
self.split = split
self.transform = transform
self.delta_timestamps = delta_timestamps
# self.data_dict = load_dataset(f"lerobot/{self.dataset_id}", revision=self.version, split="train")
self.data_dict = load_dataset(f"lerobot/{self.dataset_id}", split="train")
if self.root is not None:
self.data_dict = load_from_disk(Path(self.root) / self.dataset_id / self.split)
else:
self.data_dict = load_dataset(
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
)
self.data_dict = self.data_dict.with_format("torch")
self.data_dict.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
@property
def num_samples(self) -> int:
@@ -57,6 +66,15 @@ class PushtDataset(torch.utils.data.Dataset):
self.delta_timestamps,
)
# convert images from channel last (PIL) to channel first (pytorch)
for key in self.image_keys:
if item[key].ndim == 3:
item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w
elif item[key].ndim == 4:
item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w
else:
raise ValueError(item[key].ndim)
if self.transform is not None:
item = self.transform(item)

View File

@@ -1,5 +1,7 @@
from pathlib import Path
import torch
from datasets import load_dataset
from datasets import load_dataset, load_from_disk
from lerobot.common.datasets.utils import load_previous_and_future_frames
@@ -19,15 +21,24 @@ class XarmDataset(torch.utils.data.Dataset):
self,
dataset_id: str,
version: str | None = "v1.0",
root: Path | None = None,
split: str = "train",
transform: callable = None,
delta_timestamps: dict[list[float]] | None = None,
):
super().__init__()
self.dataset_id = dataset_id
self.version = version
self.root = root
self.split = split
self.transform = transform
self.delta_timestamps = delta_timestamps
self.data_dict = load_dataset(f"lerobot/{self.dataset_id}", revision=self.version, split="train")
if self.root is not None:
self.data_dict = load_from_disk(Path(self.root) / self.dataset_id / self.split)
else:
self.data_dict = load_dataset(
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
)
self.data_dict = self.data_dict.with_format("torch")
@property
@@ -51,6 +62,15 @@ class XarmDataset(torch.utils.data.Dataset):
self.delta_timestamps,
)
# convert images from channel last (PIL) to channel first (pytorch)
for key in self.image_keys:
if item[key].ndim == 3:
item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w
elif item[key].ndim == 4:
item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w
else:
raise ValueError(item[key].ndim)
if self.transform is not None:
item = self.transform(item)