WIP

WIP train.py works, loss going down

WIP eval.py

Fix

WIP (eval running, TODO: verify results reproduced)

Eval works! (testing reproducibility)

WIP

pretrained model pusht reproduces same results as torchrl

pretrained model pusht reproduces same results as torchrl

Remove AbstractPolicy, Move all queues in select_action

WIP test_datasets passed (TODO: re-enable NormalizeTransform)
This commit is contained in:
Cadene
2024-03-31 15:05:25 +00:00
parent 920e0d118b
commit 1cdfbc8b52
17 changed files with 826 additions and 621 deletions

View File

@@ -1,26 +1,13 @@
import logging
from pathlib import Path
from typing import Callable
import einops
import gdown
import h5py
import torch
import torchrl
import tqdm
from tensordict import TensorDict
from torchrl.data.replay_buffers.samplers import Sampler
from torchrl.data.replay_buffers.storages import TensorStorage
from torchrl.data.replay_buffers.writers import Writer
from lerobot.common.datasets.abstract import AbstractDataset
DATASET_IDS = [
"aloha_sim_insertion_human",
"aloha_sim_insertion_scripted",
"aloha_sim_transfer_cube_human",
"aloha_sim_transfer_cube_scripted",
]
from lerobot.common.datasets.utils import load_data_with_delta_timestamps
FOLDER_URLS = {
"aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF",
@@ -66,7 +53,6 @@ CAMERAS = {
def download(data_dir, dataset_id):
assert dataset_id in DATASET_IDS
assert dataset_id in FOLDER_URLS
assert dataset_id in EP48_URLS
assert dataset_id in EP49_URLS
@@ -80,51 +66,78 @@ def download(data_dir, dataset_id):
gdown.download(EP49_URLS[dataset_id], output=str(data_dir / "episode_49.hdf5"), fuzzy=True)
class AlohaDataset(AbstractDataset):
available_datasets = DATASET_IDS
class AlohaDataset(torch.utils.data.Dataset):
available_datasets = [
"aloha_sim_insertion_human",
"aloha_sim_insertion_scripted",
"aloha_sim_transfer_cube_human",
"aloha_sim_transfer_cube_scripted",
]
fps = 50
image_keys = ["observation.images.top"]
def __init__(
self,
dataset_id: str,
version: str | None = "v1.2",
batch_size: int | None = None,
*,
shuffle: bool = True,
root: Path | None = None,
pin_memory: bool = False,
prefetch: int = None,
sampler: Sampler | None = None,
collate_fn: Callable | None = None,
writer: Writer | None = None,
transform: "torchrl.envs.Transform" = None,
transform: callable = None,
delta_timestamps: dict[list[float]] | None = None,
):
super().__init__(
dataset_id,
version,
batch_size,
shuffle=shuffle,
root=root,
pin_memory=pin_memory,
prefetch=prefetch,
sampler=sampler,
collate_fn=collate_fn,
writer=writer,
transform=transform,
)
super().__init__()
self.dataset_id = dataset_id
self.version = version
self.root = root
self.transform = transform
self.delta_timestamps = delta_timestamps
data_dir = self.root / f"{self.dataset_id}"
if (data_dir / "data_dict.pth").exists() and (data_dir / "data_ids_per_episode.pth").exists():
self.data_dict = torch.load(data_dir / "data_dict.pth")
self.data_ids_per_episode = torch.load(data_dir / "data_ids_per_episode.pth")
else:
self._download_and_preproc_obsolete()
data_dir.mkdir(parents=True, exist_ok=True)
torch.save(self.data_dict, data_dir / "data_dict.pth")
torch.save(self.data_ids_per_episode, data_dir / "data_ids_per_episode.pth")
@property
def stats_patterns(self) -> dict:
d = {
("observation", "state"): "b c -> c",
("action",): "b c -> c",
}
for cam in CAMERAS[self.dataset_id]:
d[("observation", "image", cam)] = "b c h w -> c 1 1"
return d
def num_samples(self) -> int:
return len(self.data_dict["index"])
@property
def image_keys(self) -> list:
return [("observation", "image", cam) for cam in CAMERAS[self.dataset_id]]
def num_episodes(self) -> int:
return len(self.data_ids_per_episode)
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
item = {}
# get episode id and timestamp of the sampled frame
current_ts = self.data_dict["timestamp"][idx].item()
episode = self.data_dict["episode"][idx].item()
for key in self.data_dict:
if self.delta_timestamps is not None and key in self.delta_timestamps:
data, is_pad = load_data_with_delta_timestamps(
self.data_dict,
self.data_ids_per_episode,
self.delta_timestamps,
key,
current_ts,
episode,
)
item[key] = data
item[f"{key}_is_pad"] = is_pad
else:
item[key] = self.data_dict[key][idx]
if self.transform is not None:
item = self.transform(item)
return item
def _download_and_preproc_obsolete(self):
assert self.root is not None
@@ -132,54 +145,55 @@ class AlohaDataset(AbstractDataset):
if not raw_dir.is_dir():
download(raw_dir, self.dataset_id)
total_num_frames = 0
total_frames = 0
logging.info("Compute total number of frames to initialize offline buffer")
for ep_id in range(NUM_EPISODES[self.dataset_id]):
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
with h5py.File(ep_path, "r") as ep:
total_num_frames += ep["/action"].shape[0] - 1
logging.info(f"{total_num_frames=}")
total_frames += ep["/action"].shape[0] - 1
logging.info(f"{total_frames=}")
self.data_ids_per_episode = {}
ep_dicts = []
logging.info("Initialize and feed offline buffer")
idxtd = 0
for ep_id in tqdm.tqdm(range(NUM_EPISODES[self.dataset_id])):
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
with h5py.File(ep_path, "r") as ep:
ep_num_frames = ep["/action"].shape[0]
num_frames = ep["/action"].shape[0]
# last step of demonstration is considered done
done = torch.zeros(ep_num_frames, 1, dtype=torch.bool)
done = torch.zeros(num_frames, 1, dtype=torch.bool)
done[-1] = True
state = torch.from_numpy(ep["/observations/qpos"][:])
action = torch.from_numpy(ep["/action"][:])
ep_td = TensorDict(
{
("observation", "state"): state[:-1],
"action": action[:-1],
"episode": torch.tensor([ep_id] * (ep_num_frames - 1)),
"frame_id": torch.arange(0, ep_num_frames - 1, 1),
("next", "observation", "state"): state[1:],
# TODO: compute reward and success
# ("next", "reward"): reward[1:],
("next", "done"): done[1:],
# ("next", "success"): success[1:],
},
batch_size=ep_num_frames - 1,
)
ep_dict = {
"observation.state": state,
"action": action,
"episode": torch.tensor([ep_id] * num_frames),
"frame_id": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / self.fps,
# "next.observation.state": state,
# TODO(rcadene): compute reward and success
# "next.reward": reward[1:],
"next.done": done[1:],
# "next.success": success[1:],
}
for cam in CAMERAS[self.dataset_id]:
image = torch.from_numpy(ep[f"/observations/images/{cam}"][:])
image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
ep_td["observation", "image", cam] = image[:-1]
ep_td["next", "observation", "image", cam] = image[1:]
ep_dict[f"observation.images.{cam}"] = image[:-1]
# ep_dict[f"next.observation.images.{cam}"] = image[1:]
if ep_id == 0:
# hack to initialize tensordict data structure to store episodes
td_data = ep_td[0].expand(total_num_frames).memmap_like(self.root / f"{self.dataset_id}")
ep_dicts.append(ep_dict)
td_data[idxtd : idxtd + len(ep_td)] = ep_td
idxtd = idxtd + len(ep_td)
self.data_dict = {}
return TensorStorage(td_data.lock_())
keys = ep_dicts[0].keys()
for key in keys:
self.data_dict[key] = torch.cat([x[key] for x in ep_dicts])
self.data_dict["index"] = torch.arange(0, total_frames, 1)