Improve dataset examples (#82)

Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
Remi
2024-04-18 11:43:16 +02:00
committed by GitHub
parent d5c4b0c344
commit 0928afd37d
15 changed files with 274 additions and 165 deletions

View File

@@ -40,31 +40,31 @@ class AlohaDataset(torch.utils.data.Dataset):
self.transform = transform
self.delta_timestamps = delta_timestamps
if self.root is not None:
self.data_dict = load_from_disk(Path(self.root) / self.dataset_id / self.split)
self.hf_dataset = load_from_disk(Path(self.root) / self.dataset_id / self.split)
else:
self.data_dict = load_dataset(
self.hf_dataset = load_dataset(
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
)
self.data_dict = self.data_dict.with_format("torch")
self.hf_dataset = self.hf_dataset.with_format("torch")
@property
def num_samples(self) -> int:
return len(self.data_dict)
return len(self.hf_dataset)
@property
def num_episodes(self) -> int:
return len(self.data_dict.unique("episode_id"))
return len(self.hf_dataset.unique("episode_id"))
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
item = self.data_dict[idx]
item = self.hf_dataset[idx]
if self.delta_timestamps is not None:
item = load_previous_and_future_frames(
item,
self.data_dict,
self.hf_dataset,
self.delta_timestamps,
)

View File

@@ -23,7 +23,7 @@ class PushtDataset(torch.utils.data.Dataset):
def __init__(
self,
dataset_id: str,
dataset_id: str = "pusht",
version: str | None = "v1.0",
root: Path | None = None,
split: str = "train",
@@ -38,31 +38,31 @@ class PushtDataset(torch.utils.data.Dataset):
self.transform = transform
self.delta_timestamps = delta_timestamps
if self.root is not None:
self.data_dict = load_from_disk(Path(self.root) / self.dataset_id / self.split)
self.hf_dataset = load_from_disk(Path(self.root) / self.dataset_id / self.split)
else:
self.data_dict = load_dataset(
self.hf_dataset = load_dataset(
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
)
self.data_dict = self.data_dict.with_format("torch")
self.hf_dataset = self.hf_dataset.with_format("torch")
@property
def num_samples(self) -> int:
return len(self.data_dict)
return len(self.hf_dataset)
@property
def num_episodes(self) -> int:
return len(self.data_dict.unique("episode_id"))
return len(self.hf_dataset.unique("episode_id"))
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
item = self.data_dict[idx]
item = self.hf_dataset[idx]
if self.delta_timestamps is not None:
item = load_previous_and_future_frames(
item,
self.data_dict,
self.hf_dataset,
self.delta_timestamps,
)

View File

@@ -1,6 +1,7 @@
from copy import deepcopy
from math import ceil
import datasets
import einops
import torch
import tqdm
@@ -8,7 +9,7 @@ import tqdm
def load_previous_and_future_frames(
item: dict[str, torch.Tensor],
data_dict: dict[str, torch.Tensor],
hf_dataset: datasets.Dataset,
delta_timestamps: dict[str, list[float]],
tol: float = 0.04,
) -> dict[torch.Tensor]:
@@ -24,7 +25,7 @@ def load_previous_and_future_frames(
Parameters:
- item (dict): A dictionary containing all the data related to a frame. It is the result of `dataset[idx]`. Each key corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
- data_dict (dict): A dictionary containing the full dataset. Each key corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
- hf_dataset (datasets.Dataset): A dictionary containing the full dataset. Each key corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
- delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be retrieved. These deltas are added to the item timestamp to form the query timestamps.
- tol (float, optional): The tolerance level used to determine if a data point is close enough to the query timestamp. Defaults to 0.04.
@@ -40,7 +41,7 @@ def load_previous_and_future_frames(
ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to, 1)
# load timestamps
ep_timestamps = data_dict.select_columns("timestamp")[ep_data_id_from:ep_data_id_to]["timestamp"]
ep_timestamps = hf_dataset.select_columns("timestamp")[ep_data_id_from:ep_data_id_to]["timestamp"]
# we make the assumption that the timestamps are sorted
ep_first_ts = ep_timestamps[0]
@@ -70,7 +71,7 @@ def load_previous_and_future_frames(
data_ids = ep_data_ids[argmin_]
# load frames modality
item[key] = data_dict.select_columns(key)[data_ids][key]
item[key] = hf_dataset.select_columns(key)[data_ids][key]
item[f"{key}_is_pad"] = is_pad
return item

View File

@@ -19,7 +19,7 @@ class XarmDataset(torch.utils.data.Dataset):
def __init__(
self,
dataset_id: str,
dataset_id: str = "xarm_lift_medium",
version: str | None = "v1.0",
root: Path | None = None,
split: str = "train",
@@ -34,31 +34,31 @@ class XarmDataset(torch.utils.data.Dataset):
self.transform = transform
self.delta_timestamps = delta_timestamps
if self.root is not None:
self.data_dict = load_from_disk(Path(self.root) / self.dataset_id / self.split)
self.hf_dataset = load_from_disk(Path(self.root) / self.dataset_id / self.split)
else:
self.data_dict = load_dataset(
self.hf_dataset = load_dataset(
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
)
self.data_dict = self.data_dict.with_format("torch")
self.hf_dataset = self.hf_dataset.with_format("torch")
@property
def num_samples(self) -> int:
return len(self.data_dict)
return len(self.hf_dataset)
@property
def num_episodes(self) -> int:
return len(self.data_dict.unique("episode_id"))
return len(self.hf_dataset.unique("episode_id"))
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
item = self.data_dict[idx]
item = self.hf_dataset[idx]
if self.delta_timestamps is not None:
item = load_previous_and_future_frames(
item,
self.data_dict,
self.hf_dataset,
self.delta_timestamps,
)