WIP
WIP WIP train.py works, loss going down WIP eval.py Fix WIP (eval running, TODO: verify results reproduced) Eval works! (testing reproducibility) WIP pretrained model pusht reproduces same results as torchrl pretrained model pusht reproduces same results as torchrl Remove AbstractPolicy, Move all queues in select_action WIP test_datasets passed (TODO: re-enable NormalizeTransform)
This commit is contained in:
@@ -1,20 +1,13 @@
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import pygame
|
||||
import pymunk
|
||||
import torch
|
||||
import torchrl
|
||||
import tqdm
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.replay_buffers.samplers import Sampler
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage
|
||||
from torchrl.data.replay_buffers.writers import Writer
|
||||
|
||||
from lerobot.common.datasets.abstract import AbstractDataset
|
||||
from lerobot.common.datasets.utils import download_and_extract_zip
|
||||
from lerobot.common.datasets.utils import download_and_extract_zip, load_data_with_delta_timestamps
|
||||
from lerobot.common.envs.pusht.pusht_env import pymunk_to_shapely
|
||||
from lerobot.common.policies.diffusion.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
|
||||
|
||||
@@ -83,37 +76,82 @@ def add_tee(
|
||||
return body
|
||||
|
||||
|
||||
class PushtDataset(AbstractDataset):
|
||||
class PushtDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
|
||||
Arguments
|
||||
----------
|
||||
delta_timestamps : dict[list[float]] | None, optional
|
||||
Loads data from frames with a shift in timestamps with a different strategy for each data key (e.g. state, action or image)
|
||||
If `None`, no shift is applied to current timestamp and the data from the current frame is loaded.
|
||||
"""
|
||||
|
||||
available_datasets = ["pusht"]
|
||||
fps = 10
|
||||
image_keys = ["observation.image"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id: str,
|
||||
version: str | None = "v1.2",
|
||||
batch_size: int | None = None,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
root: Path | None = None,
|
||||
pin_memory: bool = False,
|
||||
prefetch: int = None,
|
||||
sampler: Sampler | None = None,
|
||||
collate_fn: Callable | None = None,
|
||||
writer: Writer | None = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
transform: callable = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
dataset_id,
|
||||
version,
|
||||
batch_size,
|
||||
shuffle=shuffle,
|
||||
root=root,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch,
|
||||
sampler=sampler,
|
||||
collate_fn=collate_fn,
|
||||
writer=writer,
|
||||
transform=transform,
|
||||
)
|
||||
super().__init__()
|
||||
self.dataset_id = dataset_id
|
||||
self.version = version
|
||||
self.root = root
|
||||
self.transform = transform
|
||||
self.delta_timestamps = delta_timestamps
|
||||
|
||||
data_dir = self.root / f"{self.dataset_id}"
|
||||
if (data_dir / "data_dict.pth").exists() and (data_dir / "data_ids_per_episode.pth").exists():
|
||||
self.data_dict = torch.load(data_dir / "data_dict.pth")
|
||||
self.data_ids_per_episode = torch.load(data_dir / "data_ids_per_episode.pth")
|
||||
else:
|
||||
self._download_and_preproc_obsolete()
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(self.data_dict, data_dir / "data_dict.pth")
|
||||
torch.save(self.data_ids_per_episode, data_dir / "data_ids_per_episode.pth")
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
return len(self.data_dict["index"])
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
return len(self.data_ids_per_episode)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = {}
|
||||
|
||||
# get episode id and timestamp of the sampled frame
|
||||
current_ts = self.data_dict["timestamp"][idx].item()
|
||||
episode = self.data_dict["episode"][idx].item()
|
||||
|
||||
for key in self.data_dict:
|
||||
if self.delta_timestamps is not None and key in self.delta_timestamps:
|
||||
data, is_pad = load_data_with_delta_timestamps(
|
||||
self.data_dict,
|
||||
self.data_ids_per_episode,
|
||||
self.delta_timestamps,
|
||||
key,
|
||||
current_ts,
|
||||
episode,
|
||||
)
|
||||
item[key] = data
|
||||
item[f"{key}_is_pad"] = is_pad
|
||||
else:
|
||||
item[key] = self.data_dict[key][idx]
|
||||
|
||||
if self.transform is not None:
|
||||
item = self.transform(item)
|
||||
|
||||
return item
|
||||
|
||||
def _download_and_preproc_obsolete(self):
|
||||
assert self.root is not None
|
||||
@@ -147,8 +185,10 @@ class PushtDataset(AbstractDataset):
|
||||
states = torch.from_numpy(dataset_dict["state"])
|
||||
actions = torch.from_numpy(dataset_dict["action"])
|
||||
|
||||
self.data_ids_per_episode = {}
|
||||
ep_dicts = []
|
||||
|
||||
idx0 = 0
|
||||
idxtd = 0
|
||||
for episode_id in tqdm.tqdm(range(num_episodes)):
|
||||
idx1 = dataset_dict.meta["episode_ends"][episode_id]
|
||||
# to create test artifact
|
||||
@@ -194,30 +234,45 @@ class PushtDataset(AbstractDataset):
|
||||
# last step of demonstration is considered done
|
||||
done[-1] = True
|
||||
|
||||
ep_td = TensorDict(
|
||||
{
|
||||
("observation", "image"): image[:-1],
|
||||
("observation", "state"): agent_pos[:-1],
|
||||
"action": actions[idx0:idx1][:-1],
|
||||
"episode": episode_ids[idx0:idx1][:-1],
|
||||
"frame_id": torch.arange(0, num_frames - 1, 1),
|
||||
("next", "observation", "image"): image[1:],
|
||||
("next", "observation", "state"): agent_pos[1:],
|
||||
# TODO: verify that reward and done are aligned with image and agent_pos
|
||||
("next", "reward"): reward[1:],
|
||||
("next", "done"): done[1:],
|
||||
("next", "success"): success[1:],
|
||||
},
|
||||
batch_size=num_frames - 1,
|
||||
)
|
||||
ep_dict = {
|
||||
"observation.image": image,
|
||||
"observation.state": agent_pos,
|
||||
"action": actions[idx0:idx1],
|
||||
"episode": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
||||
"frame_id": torch.arange(0, num_frames, 1),
|
||||
"timestamp": torch.arange(0, num_frames, 1) / self.fps,
|
||||
# "next.observation.image": image[1:],
|
||||
# "next.observation.state": agent_pos[1:],
|
||||
# TODO(rcadene): verify that reward and done are aligned with image and agent_pos
|
||||
"next.reward": torch.cat([reward[1:], reward[[-1]]]),
|
||||
"next.done": torch.cat([done[1:], done[[-1]]]),
|
||||
"next.success": torch.cat([success[1:], success[[-1]]]),
|
||||
}
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
if episode_id == 0:
|
||||
# hack to initialize tensordict data structure to store episodes
|
||||
td_data = ep_td[0].expand(total_frames).memmap_like(self.root / f"{self.dataset_id}")
|
||||
|
||||
td_data[idxtd : idxtd + len(ep_td)] = ep_td
|
||||
assert isinstance(episode_id, int)
|
||||
self.data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1)
|
||||
assert len(self.data_ids_per_episode[episode_id]) == num_frames
|
||||
|
||||
idx0 = idx1
|
||||
idxtd = idxtd + len(ep_td)
|
||||
|
||||
return TensorStorage(td_data.lock_())
|
||||
self.data_dict = {}
|
||||
|
||||
keys = ep_dicts[0].keys()
|
||||
for key in keys:
|
||||
self.data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
||||
|
||||
self.data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
dataset = PushtDataset(
|
||||
"pusht",
|
||||
root=Path("data"),
|
||||
delta_timestamps={
|
||||
"observation.image": [0, -1, -0.2, -0.1],
|
||||
"observation.state": [0, -1, -0.2, -0.1],
|
||||
"action": [-0.1, 0, 1, 2, 3],
|
||||
},
|
||||
)
|
||||
dataset[10]
|
||||
|
||||
Reference in New Issue
Block a user