WIP

WIP train.py works, loss going down

WIP eval.py

Fix

WIP (eval running, TODO: verify results reproduced)

Eval works! (testing reproducibility)

WIP

pretrained model pusht reproduces same results as torchrl

pretrained model pusht reproduces same results as torchrl

Remove AbstractPolicy, Move all queues in select_action

WIP test_datasets passed (TODO: re-enable NormalizeTransform)
This commit is contained in:
Cadene
2024-03-31 15:05:25 +00:00
parent 920e0d118b
commit 1cdfbc8b52
17 changed files with 826 additions and 621 deletions

View File

@@ -1,75 +1,104 @@
import pickle
import zipfile
from pathlib import Path
from typing import Callable
import torch
import torchrl
import tqdm
from tensordict import TensorDict
from torchrl.data.replay_buffers.samplers import (
Sampler,
)
from torchrl.data.replay_buffers.storages import TensorStorage
from torchrl.data.replay_buffers.writers import Writer
from lerobot.common.datasets.abstract import AbstractDataset
from lerobot.common.datasets.utils import load_data_with_delta_timestamps
def download():
raise NotImplementedError()
def download(raw_dir):
import gdown
raw_dir.mkdir(parents=True, exist_ok=True)
url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
download_path = "data.zip"
gdown.download(url, download_path, quiet=False)
zip_path = raw_dir / "data.zip"
gdown.download(url, str(zip_path), quiet=False)
print("Extracting...")
with zipfile.ZipFile(download_path, "r") as zip_f:
with zipfile.ZipFile(str(zip_path), "r") as zip_f:
for member in zip_f.namelist():
if member.startswith("data/xarm") and member.endswith(".pkl"):
print(member)
zip_f.extract(member=member)
Path(download_path).unlink()
zip_path.unlink()
class SimxarmDataset(AbstractDataset):
class SimxarmDataset(torch.utils.data.Dataset):
available_datasets = [
"xarm_lift_medium",
]
fps = 15
image_keys = ["observation.image"]
def __init__(
self,
dataset_id: str,
version: str | None = "v1.1",
batch_size: int | None = None,
*,
shuffle: bool = True,
root: Path | None = None,
pin_memory: bool = False,
prefetch: int = None,
sampler: Sampler | None = None,
collate_fn: Callable | None = None,
writer: Writer | None = None,
transform: "torchrl.envs.Transform" = None,
transform: callable = None,
delta_timestamps: dict[list[float]] | None = None,
):
super().__init__(
dataset_id,
version,
batch_size,
shuffle=shuffle,
root=root,
pin_memory=pin_memory,
prefetch=prefetch,
sampler=sampler,
collate_fn=collate_fn,
writer=writer,
transform=transform,
)
super().__init__()
self.dataset_id = dataset_id
self.version = version
self.root = root
self.transform = transform
self.delta_timestamps = delta_timestamps
data_dir = self.root / f"{self.dataset_id}"
if (data_dir / "data_dict.pth").exists() and (data_dir / "data_ids_per_episode.pth").exists():
self.data_dict = torch.load(data_dir / "data_dict.pth")
self.data_ids_per_episode = torch.load(data_dir / "data_ids_per_episode.pth")
else:
self._download_and_preproc_obsolete()
data_dir.mkdir(parents=True, exist_ok=True)
torch.save(self.data_dict, data_dir / "data_dict.pth")
torch.save(self.data_ids_per_episode, data_dir / "data_ids_per_episode.pth")
@property
def num_samples(self) -> int:
return len(self.data_dict["index"])
@property
def num_episodes(self) -> int:
return len(self.data_ids_per_episode)
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
item = {}
# get episode id and timestamp of the sampled frame
current_ts = self.data_dict["timestamp"][idx].item()
episode = self.data_dict["episode"][idx].item()
for key in self.data_dict:
if self.delta_timestamps is not None and key in self.delta_timestamps:
data, is_pad = load_data_with_delta_timestamps(
self.data_dict,
self.data_ids_per_episode,
self.delta_timestamps,
key,
current_ts,
episode,
)
item[key] = data
item[f"{key}_is_pad"] = is_pad
else:
item[key] = self.data_dict[key][idx]
if self.transform is not None:
item = self.transform(item)
return item
def _download_and_preproc_obsolete(self):
# assert self.root is not None
# TODO(rcadene): finish download
# download()
assert self.root is not None
raw_dir = self.root / f"{self.dataset_id}_raw"
if not raw_dir.exists():
download(raw_dir)
dataset_path = self.root / f"{self.dataset_id}" / "buffer.pkl"
print(f"Using offline dataset '{dataset_path}'")
@@ -78,6 +107,9 @@ class SimxarmDataset(AbstractDataset):
total_frames = dataset_dict["actions"].shape[0]
self.data_ids_per_episode = {}
ep_dicts = []
idx0 = 0
idx1 = 0
episode_id = 0
@@ -91,37 +123,38 @@ class SimxarmDataset(AbstractDataset):
image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1])
state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1])
next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1])
action = torch.tensor(dataset_dict["actions"][idx0:idx1])
# TODO(rcadene): concat the last "next_observations" to "observations"
# next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1])
# next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1])
next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1])
next_done = torch.tensor(dataset_dict["dones"][idx0:idx1])
episode = TensorDict(
{
("observation", "image"): image,
("observation", "state"): state,
"action": torch.tensor(dataset_dict["actions"][idx0:idx1]),
"episode": torch.tensor([episode_id] * num_frames, dtype=torch.int),
"frame_id": torch.arange(0, num_frames, 1),
("next", "observation", "image"): next_image,
("next", "observation", "state"): next_state,
("next", "reward"): next_reward,
("next", "done"): next_done,
},
batch_size=num_frames,
)
ep_dict = {
"observation.image": image,
"observation.state": state,
"action": action,
"episode": torch.tensor([episode_id] * num_frames, dtype=torch.int),
"frame_id": torch.arange(0, num_frames, 1),
"timestamp": torch.arange(0, num_frames, 1) / self.fps,
# "next.observation.image": next_image,
# "next.observation.state": next_state,
"next.reward": next_reward,
"next.done": next_done,
}
ep_dicts.append(ep_dict)
if episode_id == 0:
# hack to initialize tensordict data structure to store episodes
td_data = (
episode[0]
.expand(total_frames)
.memmap_like(self.root / f"{self.dataset_id}" / "replay_buffer")
)
assert isinstance(episode_id, int)
self.data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1)
assert len(self.data_ids_per_episode[episode_id]) == num_frames
td_data[idx0:idx1] = episode
episode_id += 1
idx0 = idx1
episode_id += 1
return TensorStorage(td_data.lock_())
self.data_dict = {}
keys = ep_dicts[0].keys()
for key in keys:
self.data_dict[key] = torch.cat([x[key] for x in ep_dicts])
self.data_dict["index"] = torch.arange(0, total_frames, 1)