HF datasets works

This commit is contained in:
Cadene
2024-04-16 12:20:38 +00:00
parent 5edd9a89a0
commit 0980fff6cc
42 changed files with 630 additions and 87 deletions

View File

@@ -1,5 +1,7 @@
from pathlib import Path
import torch
from datasets import load_dataset
from datasets import load_dataset, load_from_disk
from lerobot.common.datasets.utils import load_previous_and_future_frames
@@ -25,15 +27,24 @@ class AlohaDataset(torch.utils.data.Dataset):
self,
dataset_id: str,
version: str | None = "v1.0",
root: Path | None = None,
split: str = "train",
transform: callable = None,
delta_timestamps: dict[list[float]] | None = None,
):
super().__init__()
self.dataset_id = dataset_id
self.version = version
self.root = root
self.split = split
self.transform = transform
self.delta_timestamps = delta_timestamps
self.data_dict = load_dataset(f"lerobot/{self.dataset_id}", revision=self.version, split="train")
if self.root is not None:
self.data_dict = load_from_disk(Path(self.root) / self.dataset_id / self.split)
else:
self.data_dict = load_dataset(
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
)
self.data_dict = self.data_dict.with_format("torch")
@property
@@ -57,6 +68,15 @@ class AlohaDataset(torch.utils.data.Dataset):
self.delta_timestamps,
)
# convert images from channel last (PIL) to channel first (pytorch)
for key in self.image_keys:
if item[key].ndim == 3:
item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w
elif item[key].ndim == 4:
item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w
else:
raise ValueError(item[key].ndim)
if self.transform is not None:
item = self.transform(item)