For PushT: use HF datasets for training data; rename load_data_with_delta_timestamps -> load_previous_and_future_frames
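Summary of the change: instead of building data_dict.pth / data_ids_per_episode.pth tensors and caching them locally, PushtDataset now pulls its frames from the Hugging Face Hub via datasets.load_dataset, and the per-key delta-timestamp gathering in __getitem__ collapses into a single call to the renamed helper. A rough usage sketch follows; the import path, constructor signature, and feature keys are assumptions for illustration, not taken from this diff.

    # Hypothetical usage sketch: import path, constructor args, and key names
    # are assumptions; only delta_timestamps and the *_is_pad outputs appear
    # in the diff itself.
    from lerobot.common.datasets.pusht import PushtDataset

    delta_timestamps = {
        # for each sampled frame, also fetch the image 0.1 s earlier and the
        # actions at +0.1 s and +0.2 s, flagged as padding past episode ends
        "observation.image": [-0.1, 0.0],
        "action": [0.0, 0.1, 0.2],
    }

    dataset = PushtDataset(delta_timestamps=delta_timestamps)
    item = dataset[0]
    assert item["action"].shape[0] == 3         # one frame per requested offset
    assert item["action_is_pad"].shape[0] == 3  # boolean padding mask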
@@ -4,11 +4,12 @@ import einops
 import numpy as np
 import torch
 import tqdm
+from datasets import Dataset, load_dataset
 
 from lerobot.common.datasets._diffusion_policy_replay_buffer import (
     ReplayBuffer as DiffusionPolicyReplayBuffer,
 )
-from lerobot.common.datasets.utils import download_and_extract_zip, load_data_with_delta_timestamps
+from lerobot.common.datasets.utils import download_and_extract_zip, load_previous_and_future_frames
 
 # as define in env
 SUCCESS_THRESHOLD = 0.95  # 95% coverage,
@@ -46,50 +47,41 @@ class PushtDataset(torch.utils.data.Dataset):
         self.transform = transform
         self.delta_timestamps = delta_timestamps
 
-        self.data_dir = self.root / f"{self.dataset_id}"
-        if (self.data_dir / "data_dict.pth").exists() and (
-            self.data_dir / "data_ids_per_episode.pth"
-        ).exists():
-            self.data_dict = torch.load(self.data_dir / "data_dict.pth")
-            self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
-        else:
-            self._download_and_preproc_obsolete()
-            self.data_dir.mkdir(parents=True, exist_ok=True)
-            torch.save(self.data_dict, self.data_dir / "data_dict.pth")
-            torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")
+        # self.data_dir = self.root / f"{self.dataset_id}"
+        # if (self.data_dir / "data_dict.pth").exists() and (
+        #     self.data_dir / "data_ids_per_episode.pth"
+        # ).exists():
+        #     self.data_dict = torch.load(self.data_dir / "data_dict.pth")
+        #     self.data_ids_per_episode = torch.load(self.data_dir / "data_ids_per_episode.pth")
+        # else:
+        #     self._download_and_preproc_obsolete()
+        #     self.data_dir.mkdir(parents=True, exist_ok=True)
+        #     torch.save(self.data_dict, self.data_dir / "data_dict.pth")
+        #     torch.save(self.data_ids_per_episode, self.data_dir / "data_ids_per_episode.pth")
+
+        self.data_dict = load_dataset("lerobot/pusht", split="train")
+        self.data_dict = self.data_dict.with_format("torch")
 
     @property
     def num_samples(self) -> int:
-        return len(self.data_dict["index"]) if "index" in self.data_dict else 0
+        return len(self.data_dict)
 
     @property
     def num_episodes(self) -> int:
-        return len(self.data_ids_per_episode)
+        return len(self.data_dict.unique("episode_id"))
 
     def __len__(self):
         return self.num_samples
 
     def __getitem__(self, idx):
-        item = {}
+        item = self.data_dict[idx]
 
-        # get episode id and timestamp of the sampled frame
-        current_ts = self.data_dict["timestamp"][idx].item()
-        episode = self.data_dict["episode"][idx].item()
-
-        for key in self.data_dict:
-            if self.delta_timestamps is not None and key in self.delta_timestamps:
-                data, is_pad = load_data_with_delta_timestamps(
-                    self.data_dict,
-                    self.data_ids_per_episode,
-                    self.delta_timestamps,
-                    key,
-                    current_ts,
-                    episode,
-                )
-                item[key] = data
-                item[f"{key}_is_pad"] = is_pad
-            else:
-                item[key] = self.data_dict[key][idx]
+        if self.delta_timestamps is not None:
+            item = load_previous_and_future_frames(
+                item,
+                self.data_dict,
+                self.delta_timestamps,
+            )
 
         if self.transform is not None:
             item = self.transform(item)
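The renamed helper takes the sampled item plus the whole dataset and the delta_timestamps dict, and no longer needs data_ids_per_episode: episode boundaries now travel with each frame as episode_data_id_from / episode_data_id_to, which is exactly what the preprocessing hunk below adds. A sketch of what it plausibly does, based only on how it is called here; this is not the actual lerobot.common.datasets.utils implementation:

    import torch

    def load_previous_and_future_frames(item, data_dict, delta_timestamps):
        # Sketch under stated assumptions: for every key in delta_timestamps,
        # fetch the frames nearest to timestamp + delta within the episode.
        ep_from = item["episode_data_id_from"].item()
        ep_to = item["episode_data_id_to"].item()
        ep_timestamps = data_dict["timestamp"][ep_from : ep_to + 1]
        for key, deltas in delta_timestamps.items():
            query_ts = item["timestamp"] + torch.tensor(deltas)
            # |query - frame| distance matrix over the episode's timestamps
            dist = torch.abs(query_ts[:, None] - ep_timestamps[None, :])
            nearest = dist.argmin(dim=1)
            item[key] = data_dict[key][ep_from + nearest]
            # a query with no nearby frame fell outside the episode: padding
            item[f"{key}_is_pad"] = dist.min(dim=1).values > 1e-4  # tolerance is a guess
        return item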
@@ -214,3 +206,16 @@ class PushtDataset(torch.utils.data.Dataset):
             self.data_dict[key] = torch.cat([x[key] for x in ep_dicts])
 
         self.data_dict["index"] = torch.arange(0, total_frames, 1)
+
+        dataset = Dataset.from_dict(self.data_dict)
+        dataset = dataset.with_format("torch")
+
+        def add_episode_data_id_from_to(frame):
+            ep_id = frame["episode"].item()
+            frame["episode_data_id_from"] = self.data_ids_per_episode[ep_id][0]
+            frame["episode_data_id_to"] = self.data_ids_per_episode[ep_id][-1]
+            return frame
+
+        dataset = dataset.map(add_episode_data_id_from_to, num_proc=4)
+        dataset = dataset.rename_column("episode", "episode_id")
+        dataset.push_to_hub("lerobot/pusht", token=True)
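After push_to_hub, the uploaded copy can be sanity-checked straight from the Hub. The column names below are taken from the diff above; the rest is a sketch:

    from datasets import load_dataset

    ds = load_dataset("lerobot/pusht", split="train").with_format("torch")
    print(len(ds))                       # total frames, i.e. num_samples
    print(len(ds.unique("episode_id")))  # episode count, i.e. num_episodes
    frame = ds[0]
    # per-frame episode bounds written by add_episode_data_id_from_to
    print(frame["episode_data_id_from"].item(), frame["episode_data_id_to"].item())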