WIP

WIP train.py works, loss going down

WIP eval.py

Fix

WIP (eval running, TODO: verify results reproduced)

Eval works! (testing reproducibility)

WIP

pretrained model pusht reproduces same results as torchrl

pretrained model pusht reproduces same results as torchrl

Remove AbstractPolicy, Move all queues in select_action

WIP test_datasets passed (TODO: re-enable NormalizeTransform)
This commit is contained in:
Cadene
2024-03-31 15:05:25 +00:00
parent 920e0d118b
commit 1cdfbc8b52
17 changed files with 826 additions and 621 deletions

View File

@@ -1,20 +1,13 @@
from pathlib import Path
from typing import Callable
import einops
import numpy as np
import pygame
import pymunk
import torch
import torchrl
import tqdm
from tensordict import TensorDict
from torchrl.data.replay_buffers.samplers import Sampler
from torchrl.data.replay_buffers.storages import TensorStorage
from torchrl.data.replay_buffers.writers import Writer
from lerobot.common.datasets.abstract import AbstractDataset
from lerobot.common.datasets.utils import download_and_extract_zip
from lerobot.common.datasets.utils import download_and_extract_zip, load_data_with_delta_timestamps
from lerobot.common.envs.pusht.pusht_env import pymunk_to_shapely
from lerobot.common.policies.diffusion.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
@@ -83,37 +76,82 @@ def add_tee(
return body
class PushtDataset(AbstractDataset):
class PushtDataset(torch.utils.data.Dataset):
"""
Arguments
----------
delta_timestamps : dict[list[float]] | None, optional
Loads data from frames with a shift in timestamps with a different strategy for each data key (e.g. state, action or image)
If `None`, no shift is applied to current timestamp and the data from the current frame is loaded.
"""
available_datasets = ["pusht"]
fps = 10
image_keys = ["observation.image"]
def __init__(
self,
dataset_id: str,
version: str | None = "v1.2",
batch_size: int | None = None,
*,
shuffle: bool = True,
root: Path | None = None,
pin_memory: bool = False,
prefetch: int = None,
sampler: Sampler | None = None,
collate_fn: Callable | None = None,
writer: Writer | None = None,
transform: "torchrl.envs.Transform" = None,
transform: callable = None,
delta_timestamps: dict[list[float]] | None = None,
):
super().__init__(
dataset_id,
version,
batch_size,
shuffle=shuffle,
root=root,
pin_memory=pin_memory,
prefetch=prefetch,
sampler=sampler,
collate_fn=collate_fn,
writer=writer,
transform=transform,
)
super().__init__()
self.dataset_id = dataset_id
self.version = version
self.root = root
self.transform = transform
self.delta_timestamps = delta_timestamps
data_dir = self.root / f"{self.dataset_id}"
if (data_dir / "data_dict.pth").exists() and (data_dir / "data_ids_per_episode.pth").exists():
self.data_dict = torch.load(data_dir / "data_dict.pth")
self.data_ids_per_episode = torch.load(data_dir / "data_ids_per_episode.pth")
else:
self._download_and_preproc_obsolete()
data_dir.mkdir(parents=True, exist_ok=True)
torch.save(self.data_dict, data_dir / "data_dict.pth")
torch.save(self.data_ids_per_episode, data_dir / "data_ids_per_episode.pth")
@property
def num_samples(self) -> int:
return len(self.data_dict["index"])
@property
def num_episodes(self) -> int:
return len(self.data_ids_per_episode)
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
item = {}
# get episode id and timestamp of the sampled frame
current_ts = self.data_dict["timestamp"][idx].item()
episode = self.data_dict["episode"][idx].item()
for key in self.data_dict:
if self.delta_timestamps is not None and key in self.delta_timestamps:
data, is_pad = load_data_with_delta_timestamps(
self.data_dict,
self.data_ids_per_episode,
self.delta_timestamps,
key,
current_ts,
episode,
)
item[key] = data
item[f"{key}_is_pad"] = is_pad
else:
item[key] = self.data_dict[key][idx]
if self.transform is not None:
item = self.transform(item)
return item
def _download_and_preproc_obsolete(self):
assert self.root is not None
@@ -147,8 +185,10 @@ class PushtDataset(AbstractDataset):
states = torch.from_numpy(dataset_dict["state"])
actions = torch.from_numpy(dataset_dict["action"])
self.data_ids_per_episode = {}
ep_dicts = []
idx0 = 0
idxtd = 0
for episode_id in tqdm.tqdm(range(num_episodes)):
idx1 = dataset_dict.meta["episode_ends"][episode_id]
# to create test artifact
@@ -194,30 +234,45 @@ class PushtDataset(AbstractDataset):
# last step of demonstration is considered done
done[-1] = True
ep_td = TensorDict(
{
("observation", "image"): image[:-1],
("observation", "state"): agent_pos[:-1],
"action": actions[idx0:idx1][:-1],
"episode": episode_ids[idx0:idx1][:-1],
"frame_id": torch.arange(0, num_frames - 1, 1),
("next", "observation", "image"): image[1:],
("next", "observation", "state"): agent_pos[1:],
# TODO: verify that reward and done are aligned with image and agent_pos
("next", "reward"): reward[1:],
("next", "done"): done[1:],
("next", "success"): success[1:],
},
batch_size=num_frames - 1,
)
ep_dict = {
"observation.image": image,
"observation.state": agent_pos,
"action": actions[idx0:idx1],
"episode": torch.tensor([episode_id] * num_frames, dtype=torch.int),
"frame_id": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / self.fps,
# "next.observation.image": image[1:],
# "next.observation.state": agent_pos[1:],
# TODO(rcadene): verify that reward and done are aligned with image and agent_pos
"next.reward": torch.cat([reward[1:], reward[[-1]]]),
"next.done": torch.cat([done[1:], done[[-1]]]),
"next.success": torch.cat([success[1:], success[[-1]]]),
}
ep_dicts.append(ep_dict)
if episode_id == 0:
# hack to initialize tensordict data structure to store episodes
td_data = ep_td[0].expand(total_frames).memmap_like(self.root / f"{self.dataset_id}")
td_data[idxtd : idxtd + len(ep_td)] = ep_td
assert isinstance(episode_id, int)
self.data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1)
assert len(self.data_ids_per_episode[episode_id]) == num_frames
idx0 = idx1
idxtd = idxtd + len(ep_td)
return TensorStorage(td_data.lock_())
self.data_dict = {}
keys = ep_dicts[0].keys()
for key in keys:
self.data_dict[key] = torch.cat([x[key] for x in ep_dicts])
self.data_dict["index"] = torch.arange(0, total_frames, 1)
if __name__ == "__main__":
dataset = PushtDataset(
"pusht",
root=Path("data"),
delta_timestamps={
"observation.image": [0, -1, -0.2, -0.1],
"observation.state": [0, -1, -0.2, -0.1],
"action": [-0.1, 0, 1, 2, 3],
},
)
dataset[10]