HF datasets works
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from datasets import load_dataset, load_from_disk
|
||||
|
||||
from lerobot.common.datasets.utils import load_previous_and_future_frames
|
||||
|
||||
@@ -23,18 +25,25 @@ class PushtDataset(torch.utils.data.Dataset):
|
||||
self,
|
||||
dataset_id: str,
|
||||
version: str | None = "v1.0",
|
||||
root: Path | None = None,
|
||||
split: str = "train",
|
||||
transform: callable = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dataset_id = dataset_id
|
||||
self.version = version
|
||||
self.root = root
|
||||
self.split = split
|
||||
self.transform = transform
|
||||
self.delta_timestamps = delta_timestamps
|
||||
# self.data_dict = load_dataset(f"lerobot/{self.dataset_id}", revision=self.version, split="train")
|
||||
self.data_dict = load_dataset(f"lerobot/{self.dataset_id}", split="train")
|
||||
if self.root is not None:
|
||||
self.data_dict = load_from_disk(Path(self.root) / self.dataset_id / self.split)
|
||||
else:
|
||||
self.data_dict = load_dataset(
|
||||
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
|
||||
)
|
||||
self.data_dict = self.data_dict.with_format("torch")
|
||||
self.data_dict.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
@@ -57,6 +66,15 @@ class PushtDataset(torch.utils.data.Dataset):
|
||||
self.delta_timestamps,
|
||||
)
|
||||
|
||||
# convert images from channel last (PIL) to channel first (pytorch)
|
||||
for key in self.image_keys:
|
||||
if item[key].ndim == 3:
|
||||
item[key] = item[key].permute((2, 0, 1)) # h w c -> c h w
|
||||
elif item[key].ndim == 4:
|
||||
item[key] = item[key].permute((0, 3, 1, 2)) # t h w c -> t c h w
|
||||
else:
|
||||
raise ValueError(item[key].ndim)
|
||||
|
||||
if self.transform is not None:
|
||||
item = self.transform(item)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user