For Pusht: use hf datasets to train, rename load_data_with_delta_timestamps -> load_previous_and_future_frames

Author: Cadene
Date: 2024-04-15 10:08:10 +00:00
Parent: 4ed55c3ba3
Commit: c6aca7fe44
5 changed files with 920 additions and 113 deletions


@@ -4,11 +4,12 @@ import einops
 import numpy as np
 import torch
 import tqdm
+from datasets import Dataset, load_dataset
 from lerobot.common.datasets._diffusion_policy_replay_buffer import (
     ReplayBuffer as DiffusionPolicyReplayBuffer,
 )
-from lerobot.common.datasets.utils import download_and_extract_zip, load_data_with_delta_timestamps
+from lerobot.common.datasets.utils import download_and_extract_zip, load_previous_and_future_frames

 # as define in env
 SUCCESS_THRESHOLD = 0.95  # 95% coverage,
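
Both hunks below hinge on `delta_timestamps`: a dict mapping a data key to a list of offsets, in seconds, relative to the sampled frame. For orientation, a hedged sketch of instantiating the class after this change; the constructor arguments and the observation/action keys are assumptions, not taken from this diff:

    from lerobot.common.datasets.pusht import PushtDataset

    # Hypothetical configuration: each entry maps a column to relative
    # timestamps (seconds) around the sampled frame.
    delta_timestamps = {
        "observation.image": [-0.1, 0.0],      # previous and current frame
        "observation.state": [-0.1, 0.0],
        "action": [i / 10 for i in range(8)],  # short future action window
    }

    dataset = PushtDataset(
        dataset_id="pusht",
        delta_timestamps=delta_timestamps,  # None disables temporal stacking
    )
    item = dataset[0]
    # item["action"] stacks the 8 requested steps; item["action_is_pad"]
    # flags offsets that fell outside the episode and were padded.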
@@ -46,50 +47,41 @@ class PushtDataset(torch.utils.data.Dataset):
         self.transform = transform
         self.delta_timestamps = delta_timestamps
-        self.data_dir = self.root / f"{self.dataset_id}"
-        if (self.data_dir / "data_dict.pth").exists() and (
-            self.data_dir / "data_ids_per_episode.pth"
-        ).exists():
-            self.data_dict = torch.load(self.data_dir / "data_dict.pth")
-            self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
-        else:
-            self._download_and_preproc_obsolete()
-            self.data_dir.mkdir(parents=True, exist_ok=True)
-            torch.save(self.data_dict, self.data_dir / "data_dict.pth")
-            torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")
+        # self.data_dir = self.root / f"{self.dataset_id}"
+        # if (self.data_dir / "data_dict.pth").exists() and (
+        #     self.data_dir / "data_ids_per_episode.pth"
+        # ).exists():
+        #     self.data_dict = torch.load(self.data_dir / "data_dict.pth")
+        #     self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
+        # else:
+        #     self._download_and_preproc_obsolete()
+        #     self.data_dir.mkdir(parents=True, exist_ok=True)
+        #     torch.save(self.data_dict, self.data_dir / "data_dict.pth")
+        #     torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")
+        self.data_dict = load_dataset("lerobot/pusht", split="train")
+        self.data_dict = self.data_dict.with_format("torch")

     @property
     def num_samples(self) -> int:
-        return len(self.data_dict["index"]) if "index" in self.data_dict else 0
+        return len(self.data_dict)

     @property
     def num_episodes(self) -> int:
-        return len(self.data_ids_per_episode)
+        return len(self.data_dict.unique("episode_id"))

     def __len__(self):
         return self.num_samples

     def __getitem__(self, idx):
-        item = {}
-
-        # get episode id and timestamp of the sampled frame
-        current_ts = self.data_dict["timestamp"][idx].item()
-        episode = self.data_dict["episode"][idx].item()
-
-        for key in self.data_dict:
-            if self.delta_timestamps is not None and key in self.delta_timestamps:
-                data, is_pad = load_data_with_delta_timestamps(
-                    self.data_dict,
-                    self.data_ids_per_episode,
-                    self.delta_timestamps,
-                    key,
-                    current_ts,
-                    episode,
-                )
-                item[key] = data
-                item[f"{key}_is_pad"] = is_pad
-            else:
-                item[key] = self.data_dict[key][idx]
+        item = self.data_dict[idx]
+
+        if self.delta_timestamps is not None:
+            item = load_previous_and_future_frames(
+                item,
+                self.data_dict,
+                self.delta_timestamps,
+            )

         if self.transform is not None:
             item = self.transform(item)
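
`load_previous_and_future_frames` folds the old per-key loop into a single call that enriches one frame in place. Its body is not part of this diff; the following is a minimal sketch of the behavior implied by the old loop and by the `episode_data_id_from`/`episode_data_id_to` columns added in the preprocessing hunk below, with the matching tolerance `tol` a pure assumption:

    import torch

    def load_previous_and_future_frames(item, data_dict, delta_timestamps, tol=0.04):
        # Sketch only. The rows of one episode are contiguous, delimited by
        # the episode_data_id_from / episode_data_id_to columns baked into
        # the hub dataset.
        ep_from = item["episode_data_id_from"].item()
        ep_to = item["episode_data_id_to"].item()
        episode = data_dict[ep_from : ep_to + 1]  # dict of torch tensors
        ep_timestamps = episode["timestamp"]
        current_ts = item["timestamp"].item()

        for key, deltas in delta_timestamps.items():
            # Timestamps we want to sample, relative to the current frame.
            query_ts = current_ts + torch.tensor(deltas)
            # Nearest frame inside the episode for each requested timestamp.
            dist = (query_ts[:, None] - ep_timestamps[None, :]).abs()
            min_dist, argmin = dist.min(dim=1)
            # Queries falling outside the episode reuse the closest edge
            # frame and are flagged so the policy can mask them out.
            item[key] = episode[key][argmin]
            item[f"{key}_is_pad"] = min_dist > tol
        return item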
@@ -214,3 +206,16 @@ class PushtDataset(torch.utils.data.Dataset):
             self.data_dict[key] = torch.cat([x[key] for x in ep_dicts])

         self.data_dict["index"] = torch.arange(0, total_frames, 1)
+
+        dataset = Dataset.from_dict(self.data_dict)
+        dataset = dataset.with_format("torch")
+
+        def add_episode_data_id_from_to(frame):
+            ep_id = frame["episode"].item()
+            frame["episode_data_id_from"] = self.data_ids_per_episode[ep_id][0]
+            frame["episode_data_id_to"] = self.data_ids_per_episode[ep_id][-1]
+            return frame
+
+        dataset = dataset.map(add_episode_data_id_from_to, num_proc=4)
+        dataset = dataset.rename_column("episode", "episode_id")
+        dataset.push_to_hub("lerobot/pusht", token=True)
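
With the preprocessed frames pushed to the hub, the local `.pth` cache becomes redundant: any machine can pull the exact same training split. A quick smoke test, assuming the `lerobot/pusht` revision is publicly readable:

    from datasets import load_dataset

    dataset = load_dataset("lerobot/pusht", split="train")
    dataset = dataset.with_format("torch")  # columns come back as torch tensors

    print(len(dataset))                       # total number of frames
    print(len(dataset.unique("episode_id")))  # number of episodes
    frame = dataset[0]                        # one frame as a dict of tensors
    print(frame["episode_data_id_from"], frame["episode_data_id_to"])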