Improve dataset examples (#82)
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
@@ -40,31 +40,31 @@ class AlohaDataset(torch.utils.data.Dataset):
|
||||
self.transform = transform
|
||||
self.delta_timestamps = delta_timestamps
|
||||
if self.root is not None:
|
||||
self.data_dict = load_from_disk(Path(self.root) / self.dataset_id / self.split)
|
||||
self.hf_dataset = load_from_disk(Path(self.root) / self.dataset_id / self.split)
|
||||
else:
|
||||
self.data_dict = load_dataset(
|
||||
self.hf_dataset = load_dataset(
|
||||
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
|
||||
)
|
||||
self.data_dict = self.data_dict.with_format("torch")
|
||||
self.hf_dataset = self.hf_dataset.with_format("torch")
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
return len(self.data_dict)
|
||||
return len(self.hf_dataset)
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
return len(self.data_dict.unique("episode_id"))
|
||||
return len(self.hf_dataset.unique("episode_id"))
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = self.data_dict[idx]
|
||||
item = self.hf_dataset[idx]
|
||||
|
||||
if self.delta_timestamps is not None:
|
||||
item = load_previous_and_future_frames(
|
||||
item,
|
||||
self.data_dict,
|
||||
self.hf_dataset,
|
||||
self.delta_timestamps,
|
||||
)
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ class PushtDataset(torch.utils.data.Dataset):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id: str,
|
||||
dataset_id: str = "pusht",
|
||||
version: str | None = "v1.0",
|
||||
root: Path | None = None,
|
||||
split: str = "train",
|
||||
@@ -38,31 +38,31 @@ class PushtDataset(torch.utils.data.Dataset):
|
||||
self.transform = transform
|
||||
self.delta_timestamps = delta_timestamps
|
||||
if self.root is not None:
|
||||
self.data_dict = load_from_disk(Path(self.root) / self.dataset_id / self.split)
|
||||
self.hf_dataset = load_from_disk(Path(self.root) / self.dataset_id / self.split)
|
||||
else:
|
||||
self.data_dict = load_dataset(
|
||||
self.hf_dataset = load_dataset(
|
||||
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
|
||||
)
|
||||
self.data_dict = self.data_dict.with_format("torch")
|
||||
self.hf_dataset = self.hf_dataset.with_format("torch")
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
return len(self.data_dict)
|
||||
return len(self.hf_dataset)
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
return len(self.data_dict.unique("episode_id"))
|
||||
return len(self.hf_dataset.unique("episode_id"))
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = self.data_dict[idx]
|
||||
item = self.hf_dataset[idx]
|
||||
|
||||
if self.delta_timestamps is not None:
|
||||
item = load_previous_and_future_frames(
|
||||
item,
|
||||
self.data_dict,
|
||||
self.hf_dataset,
|
||||
self.delta_timestamps,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from copy import deepcopy
|
||||
from math import ceil
|
||||
|
||||
import datasets
|
||||
import einops
|
||||
import torch
|
||||
import tqdm
|
||||
@@ -8,7 +9,7 @@ import tqdm
|
||||
|
||||
def load_previous_and_future_frames(
|
||||
item: dict[str, torch.Tensor],
|
||||
data_dict: dict[str, torch.Tensor],
|
||||
hf_dataset: datasets.Dataset,
|
||||
delta_timestamps: dict[str, list[float]],
|
||||
tol: float = 0.04,
|
||||
) -> dict[torch.Tensor]:
|
||||
@@ -24,7 +25,7 @@ def load_previous_and_future_frames(
|
||||
|
||||
Parameters:
|
||||
- item (dict): A dictionary containing all the data related to a frame. It is the result of `dataset[idx]`. Each key corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
|
||||
- data_dict (dict): A dictionary containing the full dataset. Each key corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
|
||||
- hf_dataset (datasets.Dataset): A Hugging Face dataset containing the full dataset. Each column corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
|
||||
- delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be retrieved. These deltas are added to the item timestamp to form the query timestamps.
|
||||
- tol (float, optional): The tolerance level used to determine if a data point is close enough to the query timestamp. Defaults to 0.04.
|
||||
|
||||
@@ -40,7 +41,7 @@ def load_previous_and_future_frames(
|
||||
ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to, 1)
|
||||
|
||||
# load timestamps
|
||||
ep_timestamps = data_dict.select_columns("timestamp")[ep_data_id_from:ep_data_id_to]["timestamp"]
|
||||
ep_timestamps = hf_dataset.select_columns("timestamp")[ep_data_id_from:ep_data_id_to]["timestamp"]
|
||||
|
||||
# we make the assumption that the timestamps are sorted
|
||||
ep_first_ts = ep_timestamps[0]
|
||||
@@ -70,7 +71,7 @@ def load_previous_and_future_frames(
|
||||
data_ids = ep_data_ids[argmin_]
|
||||
|
||||
# load frames modality
|
||||
item[key] = data_dict.select_columns(key)[data_ids][key]
|
||||
item[key] = hf_dataset.select_columns(key)[data_ids][key]
|
||||
item[f"{key}_is_pad"] = is_pad
|
||||
|
||||
return item
|
||||
|
||||
@@ -19,7 +19,7 @@ class XarmDataset(torch.utils.data.Dataset):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id: str,
|
||||
dataset_id: str = "xarm_lift_medium",
|
||||
version: str | None = "v1.0",
|
||||
root: Path | None = None,
|
||||
split: str = "train",
|
||||
@@ -34,31 +34,31 @@ class XarmDataset(torch.utils.data.Dataset):
|
||||
self.transform = transform
|
||||
self.delta_timestamps = delta_timestamps
|
||||
if self.root is not None:
|
||||
self.data_dict = load_from_disk(Path(self.root) / self.dataset_id / self.split)
|
||||
self.hf_dataset = load_from_disk(Path(self.root) / self.dataset_id / self.split)
|
||||
else:
|
||||
self.data_dict = load_dataset(
|
||||
self.hf_dataset = load_dataset(
|
||||
f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
|
||||
)
|
||||
self.data_dict = self.data_dict.with_format("torch")
|
||||
self.hf_dataset = self.hf_dataset.with_format("torch")
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
return len(self.data_dict)
|
||||
return len(self.hf_dataset)
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
return len(self.data_dict.unique("episode_id"))
|
||||
return len(self.hf_dataset.unique("episode_id"))
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = self.data_dict[idx]
|
||||
item = self.hf_dataset[idx]
|
||||
|
||||
if self.delta_timestamps is not None:
|
||||
item = load_previous_and_future_frames(
|
||||
item,
|
||||
self.data_dict,
|
||||
self.hf_dataset,
|
||||
self.delta_timestamps,
|
||||
)
|
||||
|
||||
|
||||
@@ -241,7 +241,7 @@ def eval_policy(
|
||||
|
||||
data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
|
||||
data_dict = Dataset.from_dict(data_dict).with_format("torch")
|
||||
hf_dataset = Dataset.from_dict(data_dict).with_format("torch")
|
||||
|
||||
if max_episodes_rendered > 0:
|
||||
batch_stacked_frames = np.stack(ep_frames, 1) # (b, t, *)
|
||||
@@ -292,7 +292,7 @@ def eval_policy(
|
||||
"eval_s": time.time() - start,
|
||||
"eval_ep_s": (time.time() - start) / num_episodes,
|
||||
},
|
||||
"episodes": data_dict,
|
||||
"episodes": hf_dataset,
|
||||
}
|
||||
if max_episodes_rendered > 0:
|
||||
info["videos"] = videos
|
||||
|
||||
@@ -2,10 +2,11 @@ import logging
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import hydra
|
||||
import torch
|
||||
from datasets import concatenate_datasets
|
||||
from datasets.utils.logging import disable_progress_bar
|
||||
from datasets.utils import disable_progress_bars, enable_progress_bars
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.utils import cycle
|
||||
@@ -130,15 +131,40 @@ def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float):
|
||||
return -(n_off * pc_on) / (n_on * (pc_on - 1))
|
||||
|
||||
|
||||
def add_episodes_inplace(data_dict, online_dataset, concat_dataset, sampler, pc_online_samples):
|
||||
first_episode_id = data_dict.select_columns("episode_id")[0]["episode_id"].item()
|
||||
first_index = data_dict.select_columns("index")[0]["index"].item()
|
||||
def add_episodes_inplace(
|
||||
online_dataset: torch.utils.data.Dataset,
|
||||
concat_dataset: torch.utils.data.ConcatDataset,
|
||||
sampler: torch.utils.data.WeightedRandomSampler,
|
||||
hf_dataset: datasets.Dataset,
|
||||
pc_online_samples: float,
|
||||
):
|
||||
"""
|
||||
Modifies the online_dataset, concat_dataset, and sampler in place by integrating
|
||||
new episodes from hf_dataset into the online_dataset, updating the concatenated
|
||||
dataset's structure and adjusting the sampling strategy based on the specified
|
||||
percentage of online samples.
|
||||
|
||||
Parameters:
|
||||
- online_dataset (torch.utils.data.Dataset): The existing online dataset to be updated.
|
||||
- concat_dataset (torch.utils.data.ConcatDataset): The concatenated dataset that combines
|
||||
offline and online datasets, used for sampling purposes.
|
||||
- sampler (torch.utils.data.WeightedRandomSampler): A sampler that will be updated to
|
||||
reflect changes in the dataset sizes and specified sampling weights.
|
||||
- hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added.
|
||||
- pc_online_samples (float): The target fraction of samples (e.g., 0.5 for 50%) that should come from
|
||||
the online dataset during sampling operations.
|
||||
|
||||
Raises:
|
||||
- AssertionError: If the first episode_id or index in hf_dataset is not 0.
|
||||
"""
|
||||
first_episode_id = hf_dataset.select_columns("episode_id")[0]["episode_id"].item()
|
||||
first_index = hf_dataset.select_columns("index")[0]["index"].item()
|
||||
assert first_episode_id == 0, f"We expect the first episode_id to be 0 and not {first_episode_id}"
|
||||
assert first_index == 0, f"We expect the first first_index to be 0 and not {first_index}"
|
||||
|
||||
if len(online_dataset) == 0:
|
||||
# initialize online dataset
|
||||
online_dataset.data_dict = data_dict
|
||||
online_dataset.hf_dataset = hf_dataset
|
||||
else:
|
||||
# find episode index and data frame indices according to previous episode in online_dataset
|
||||
start_episode = online_dataset.select_columns("episode_id")[-1]["episode_id"].item() + 1
|
||||
@@ -152,11 +178,12 @@ def add_episodes_inplace(data_dict, online_dataset, concat_dataset, sampler, pc_
|
||||
example["episode_data_index_to"] += start_index
|
||||
return example
|
||||
|
||||
disable_progress_bar() # map has a tqdm progress bar
|
||||
data_dict = data_dict.map(shift_indices)
|
||||
disable_progress_bars() # map has a tqdm progress bar
|
||||
hf_dataset = hf_dataset.map(shift_indices)
|
||||
enable_progress_bars()
|
||||
|
||||
# extend online dataset
|
||||
online_dataset.data_dict = concatenate_datasets([online_dataset.data_dict, data_dict])
|
||||
online_dataset.hf_dataset = concatenate_datasets([online_dataset.hf_dataset, hf_dataset])
|
||||
|
||||
# update the concatenated dataset length used during sampling
|
||||
concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
|
||||
@@ -274,7 +301,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||
|
||||
# create an empty online dataset similar to offline dataset
|
||||
online_dataset = deepcopy(offline_dataset)
|
||||
online_dataset.data_dict = {}
|
||||
online_dataset.hf_dataset = {}
|
||||
|
||||
# create dataloader for online training
|
||||
concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
|
||||
@@ -308,7 +335,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||
|
||||
online_pc_sampling = cfg.get("demo_schedule", 0.5)
|
||||
add_episodes_inplace(
|
||||
eval_info["episodes"], online_dataset, concat_dataset, sampler, online_pc_sampling
|
||||
online_dataset, concat_dataset, sampler, eval_info["episodes"], online_pc_sampling
|
||||
)
|
||||
|
||||
for _ in range(cfg.policy.utd):
|
||||
|
||||
Reference in New Issue
Block a user