WIP
WIP WIP train.py works, loss going down WIP eval.py Fix WIP (eval running, TODO: verify results reproduced) Eval works! (testing reproducibility) WIP pretrained model pusht reproduces same results as torchrl pretrained model pusht reproduces same results as torchrl Remove AbstractPolicy, Move all queues in select_action WIP test_datasets passed (TODO: re-enable NormalizeTransform)
This commit is contained in:
@@ -1,26 +1,13 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import einops
|
||||
import gdown
|
||||
import h5py
|
||||
import torch
|
||||
import torchrl
|
||||
import tqdm
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.replay_buffers.samplers import Sampler
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage
|
||||
from torchrl.data.replay_buffers.writers import Writer
|
||||
|
||||
from lerobot.common.datasets.abstract import AbstractDataset
|
||||
|
||||
DATASET_IDS = [
|
||||
"aloha_sim_insertion_human",
|
||||
"aloha_sim_insertion_scripted",
|
||||
"aloha_sim_transfer_cube_human",
|
||||
"aloha_sim_transfer_cube_scripted",
|
||||
]
|
||||
from lerobot.common.datasets.utils import load_data_with_delta_timestamps
|
||||
|
||||
FOLDER_URLS = {
|
||||
"aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF",
|
||||
@@ -66,7 +53,6 @@ CAMERAS = {
|
||||
|
||||
|
||||
def download(data_dir, dataset_id):
|
||||
assert dataset_id in DATASET_IDS
|
||||
assert dataset_id in FOLDER_URLS
|
||||
assert dataset_id in EP48_URLS
|
||||
assert dataset_id in EP49_URLS
|
||||
@@ -80,51 +66,78 @@ def download(data_dir, dataset_id):
|
||||
gdown.download(EP49_URLS[dataset_id], output=str(data_dir / "episode_49.hdf5"), fuzzy=True)
|
||||
|
||||
|
||||
class AlohaDataset(AbstractDataset):
|
||||
available_datasets = DATASET_IDS
|
||||
class AlohaDataset(torch.utils.data.Dataset):
|
||||
available_datasets = [
|
||||
"aloha_sim_insertion_human",
|
||||
"aloha_sim_insertion_scripted",
|
||||
"aloha_sim_transfer_cube_human",
|
||||
"aloha_sim_transfer_cube_scripted",
|
||||
]
|
||||
fps = 50
|
||||
image_keys = ["observation.images.top"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id: str,
|
||||
version: str | None = "v1.2",
|
||||
batch_size: int | None = None,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
root: Path | None = None,
|
||||
pin_memory: bool = False,
|
||||
prefetch: int = None,
|
||||
sampler: Sampler | None = None,
|
||||
collate_fn: Callable | None = None,
|
||||
writer: Writer | None = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
transform: callable = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
dataset_id,
|
||||
version,
|
||||
batch_size,
|
||||
shuffle=shuffle,
|
||||
root=root,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch,
|
||||
sampler=sampler,
|
||||
collate_fn=collate_fn,
|
||||
writer=writer,
|
||||
transform=transform,
|
||||
)
|
||||
super().__init__()
|
||||
self.dataset_id = dataset_id
|
||||
self.version = version
|
||||
self.root = root
|
||||
self.transform = transform
|
||||
self.delta_timestamps = delta_timestamps
|
||||
|
||||
data_dir = self.root / f"{self.dataset_id}"
|
||||
if (data_dir / "data_dict.pth").exists() and (data_dir / "data_ids_per_episode.pth").exists():
|
||||
self.data_dict = torch.load(data_dir / "data_dict.pth")
|
||||
self.data_ids_per_episode = torch.load(data_dir / "data_ids_per_episode.pth")
|
||||
else:
|
||||
self._download_and_preproc_obsolete()
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(self.data_dict, data_dir / "data_dict.pth")
|
||||
torch.save(self.data_ids_per_episode, data_dir / "data_ids_per_episode.pth")
|
||||
|
||||
@property
|
||||
def stats_patterns(self) -> dict:
|
||||
d = {
|
||||
("observation", "state"): "b c -> c",
|
||||
("action",): "b c -> c",
|
||||
}
|
||||
for cam in CAMERAS[self.dataset_id]:
|
||||
d[("observation", "image", cam)] = "b c h w -> c 1 1"
|
||||
return d
|
||||
def num_samples(self) -> int:
|
||||
return len(self.data_dict["index"])
|
||||
|
||||
@property
|
||||
def image_keys(self) -> list:
|
||||
return [("observation", "image", cam) for cam in CAMERAS[self.dataset_id]]
|
||||
def num_episodes(self) -> int:
|
||||
return len(self.data_ids_per_episode)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = {}
|
||||
|
||||
# get episode id and timestamp of the sampled frame
|
||||
current_ts = self.data_dict["timestamp"][idx].item()
|
||||
episode = self.data_dict["episode"][idx].item()
|
||||
|
||||
for key in self.data_dict:
|
||||
if self.delta_timestamps is not None and key in self.delta_timestamps:
|
||||
data, is_pad = load_data_with_delta_timestamps(
|
||||
self.data_dict,
|
||||
self.data_ids_per_episode,
|
||||
self.delta_timestamps,
|
||||
key,
|
||||
current_ts,
|
||||
episode,
|
||||
)
|
||||
item[key] = data
|
||||
item[f"{key}_is_pad"] = is_pad
|
||||
else:
|
||||
item[key] = self.data_dict[key][idx]
|
||||
|
||||
if self.transform is not None:
|
||||
item = self.transform(item)
|
||||
|
||||
return item
|
||||
|
||||
def _download_and_preproc_obsolete(self):
|
||||
assert self.root is not None
|
||||
@@ -132,54 +145,55 @@ class AlohaDataset(AbstractDataset):
|
||||
if not raw_dir.is_dir():
|
||||
download(raw_dir, self.dataset_id)
|
||||
|
||||
total_num_frames = 0
|
||||
total_frames = 0
|
||||
logging.info("Compute total number of frames to initialize offline buffer")
|
||||
for ep_id in range(NUM_EPISODES[self.dataset_id]):
|
||||
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
|
||||
with h5py.File(ep_path, "r") as ep:
|
||||
total_num_frames += ep["/action"].shape[0] - 1
|
||||
logging.info(f"{total_num_frames=}")
|
||||
total_frames += ep["/action"].shape[0] - 1
|
||||
logging.info(f"{total_frames=}")
|
||||
|
||||
self.data_ids_per_episode = {}
|
||||
ep_dicts = []
|
||||
|
||||
logging.info("Initialize and feed offline buffer")
|
||||
idxtd = 0
|
||||
for ep_id in tqdm.tqdm(range(NUM_EPISODES[self.dataset_id])):
|
||||
ep_path = raw_dir / f"episode_{ep_id}.hdf5"
|
||||
with h5py.File(ep_path, "r") as ep:
|
||||
ep_num_frames = ep["/action"].shape[0]
|
||||
num_frames = ep["/action"].shape[0]
|
||||
|
||||
# last step of demonstration is considered done
|
||||
done = torch.zeros(ep_num_frames, 1, dtype=torch.bool)
|
||||
done = torch.zeros(num_frames, 1, dtype=torch.bool)
|
||||
done[-1] = True
|
||||
|
||||
state = torch.from_numpy(ep["/observations/qpos"][:])
|
||||
action = torch.from_numpy(ep["/action"][:])
|
||||
|
||||
ep_td = TensorDict(
|
||||
{
|
||||
("observation", "state"): state[:-1],
|
||||
"action": action[:-1],
|
||||
"episode": torch.tensor([ep_id] * (ep_num_frames - 1)),
|
||||
"frame_id": torch.arange(0, ep_num_frames - 1, 1),
|
||||
("next", "observation", "state"): state[1:],
|
||||
# TODO: compute reward and success
|
||||
# ("next", "reward"): reward[1:],
|
||||
("next", "done"): done[1:],
|
||||
# ("next", "success"): success[1:],
|
||||
},
|
||||
batch_size=ep_num_frames - 1,
|
||||
)
|
||||
ep_dict = {
|
||||
"observation.state": state,
|
||||
"action": action,
|
||||
"episode": torch.tensor([ep_id] * num_frames),
|
||||
"frame_id": torch.arange(0, num_frames, 1),
|
||||
"timestamp": torch.arange(0, num_frames, 1) / self.fps,
|
||||
# "next.observation.state": state,
|
||||
# TODO(rcadene): compute reward and success
|
||||
# "next.reward": reward[1:],
|
||||
"next.done": done[1:],
|
||||
# "next.success": success[1:],
|
||||
}
|
||||
|
||||
for cam in CAMERAS[self.dataset_id]:
|
||||
image = torch.from_numpy(ep[f"/observations/images/{cam}"][:])
|
||||
image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
|
||||
ep_td["observation", "image", cam] = image[:-1]
|
||||
ep_td["next", "observation", "image", cam] = image[1:]
|
||||
ep_dict[f"observation.images.{cam}"] = image[:-1]
|
||||
# ep_dict[f"next.observation.images.{cam}"] = image[1:]
|
||||
|
||||
if ep_id == 0:
|
||||
# hack to initialize tensordict data structure to store episodes
|
||||
td_data = ep_td[0].expand(total_num_frames).memmap_like(self.root / f"{self.dataset_id}")
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
td_data[idxtd : idxtd + len(ep_td)] = ep_td
|
||||
idxtd = idxtd + len(ep_td)
|
||||
self.data_dict = {}
|
||||
|
||||
return TensorStorage(td_data.lock_())
|
||||
keys = ep_dicts[0].keys()
|
||||
for key in keys:
|
||||
self.data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
||||
|
||||
self.data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torchrl.data.replay_buffers import PrioritizedSliceSampler, SliceSampler
|
||||
from torchvision.transforms import v2
|
||||
|
||||
from lerobot.common.transforms import NormalizeTransform, Prod
|
||||
from lerobot.common.transforms import Prod
|
||||
|
||||
# DATA_DIR specifies to location where datasets are loaded. By default, DATA_DIR is None and
|
||||
# we load from `$HOME/.cache/huggingface/hub/datasets`. For our unit tests, we set `DATA_DIR=tests/data`
|
||||
@@ -13,57 +12,12 @@ from lerobot.common.transforms import NormalizeTransform, Prod
|
||||
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
||||
|
||||
|
||||
def make_offline_buffer(
|
||||
def make_dataset(
|
||||
cfg,
|
||||
overwrite_sampler=None,
|
||||
# set normalize=False to remove all transformations and keep images unnormalized in [0,255]
|
||||
normalize=True,
|
||||
overwrite_batch_size=None,
|
||||
overwrite_prefetch=None,
|
||||
stats_path=None,
|
||||
):
|
||||
if cfg.policy.balanced_sampling:
|
||||
assert cfg.online_steps > 0
|
||||
batch_size = None
|
||||
pin_memory = False
|
||||
prefetch = None
|
||||
else:
|
||||
assert cfg.online_steps == 0
|
||||
num_slices = cfg.policy.batch_size
|
||||
batch_size = cfg.policy.horizon * num_slices
|
||||
pin_memory = cfg.device == "cuda"
|
||||
prefetch = cfg.prefetch
|
||||
|
||||
if overwrite_batch_size is not None:
|
||||
batch_size = overwrite_batch_size
|
||||
|
||||
if overwrite_prefetch is not None:
|
||||
prefetch = overwrite_prefetch
|
||||
|
||||
if overwrite_sampler is None:
|
||||
# TODO(rcadene): move batch_size outside
|
||||
num_traj_per_batch = cfg.policy.batch_size # // cfg.horizon
|
||||
# TODO(rcadene): Sampler outputs a batch_size <= cfg.batch_size.
|
||||
# We would need to add a transform to pad the tensordict to ensure batch_size == cfg.batch_size.
|
||||
|
||||
if cfg.offline_prioritized_sampler:
|
||||
logging.info("use prioritized sampler for offline dataset")
|
||||
sampler = PrioritizedSliceSampler(
|
||||
max_capacity=100_000,
|
||||
alpha=cfg.policy.per_alpha,
|
||||
beta=cfg.policy.per_beta,
|
||||
num_slices=num_traj_per_batch,
|
||||
strict_length=False,
|
||||
)
|
||||
else:
|
||||
logging.info("use simple sampler for offline dataset")
|
||||
sampler = SliceSampler(
|
||||
num_slices=num_traj_per_batch,
|
||||
strict_length=False,
|
||||
)
|
||||
else:
|
||||
sampler = overwrite_sampler
|
||||
|
||||
if cfg.env.name == "simxarm":
|
||||
from lerobot.common.datasets.simxarm import SimxarmDataset
|
||||
|
||||
@@ -81,56 +35,56 @@ def make_offline_buffer(
|
||||
else:
|
||||
raise ValueError(cfg.env.name)
|
||||
|
||||
offline_buffer = clsfunc(
|
||||
dataset_id=cfg.dataset_id,
|
||||
sampler=sampler,
|
||||
batch_size=batch_size,
|
||||
root=DATA_DIR,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch if isinstance(prefetch, int) else None,
|
||||
)
|
||||
|
||||
if cfg.policy.name == "tdmpc":
|
||||
img_keys = []
|
||||
for key in offline_buffer.image_keys:
|
||||
img_keys.append(("next", *key))
|
||||
img_keys += offline_buffer.image_keys
|
||||
else:
|
||||
img_keys = offline_buffer.image_keys
|
||||
|
||||
transforms = None
|
||||
if normalize:
|
||||
transforms = [Prod(in_keys=img_keys, prod=1 / 255)]
|
||||
|
||||
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max,
|
||||
# min_max_from_spec
|
||||
stats = offline_buffer.compute_or_load_stats() if stats_path is None else torch.load(stats_path)
|
||||
# stats = dataset.compute_or_load_stats() if stats_path is None else torch.load(stats_path)
|
||||
|
||||
# we only normalize the state and action, since the images are usually normalized inside the model for
|
||||
# now (except for tdmpc: see the following)
|
||||
in_keys = [("observation", "state"), ("action")]
|
||||
|
||||
if cfg.policy.name == "tdmpc":
|
||||
# TODO(rcadene): we add img_keys to the keys to normalize for tdmpc only, since diffusion and act policies normalize the image inside the model for now
|
||||
in_keys += img_keys
|
||||
# TODO(racdene): since we use next observations in tdmpc, we also add them to the normalization. We are wasting a bit of compute on this for now.
|
||||
in_keys += [("next", *key) for key in img_keys]
|
||||
in_keys.append(("next", "observation", "state"))
|
||||
stats = {}
|
||||
|
||||
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
|
||||
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
|
||||
stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
|
||||
stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
|
||||
stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
||||
stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
||||
stats["observation.state"] = {}
|
||||
stats["observation.state"]["min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
|
||||
stats["observation.state"]["max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
|
||||
stats["action"] = {}
|
||||
stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
||||
stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
||||
|
||||
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
|
||||
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
|
||||
transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode))
|
||||
# normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
|
||||
|
||||
offline_buffer.set_transform(transforms)
|
||||
transforms = v2.Compose(
|
||||
[
|
||||
# TODO(rcadene): we need to do something about image_keys
|
||||
Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
|
||||
# NormalizeTransform(
|
||||
# stats,
|
||||
# in_keys=[
|
||||
# "observation.state",
|
||||
# "action",
|
||||
# ],
|
||||
# mode=normalization_mode,
|
||||
# ),
|
||||
]
|
||||
)
|
||||
|
||||
if not overwrite_sampler:
|
||||
index = torch.arange(0, offline_buffer.num_samples, 1)
|
||||
sampler.extend(index)
|
||||
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
|
||||
# TODO(rcadene): implement delta_timestamps in config
|
||||
delta_timestamps = {
|
||||
"observation.image": [-0.1, 0],
|
||||
"observation.state": [-0.1, 0],
|
||||
"action": [-0.1] + [i / clsfunc.fps for i in range(15)],
|
||||
}
|
||||
else:
|
||||
delta_timestamps = None
|
||||
|
||||
return offline_buffer
|
||||
dataset = clsfunc(
|
||||
dataset_id=cfg.dataset_id,
|
||||
root=DATA_DIR,
|
||||
delta_timestamps=delta_timestamps,
|
||||
transform=transforms,
|
||||
)
|
||||
|
||||
return dataset
|
||||
|
||||
@@ -1,20 +1,13 @@
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import pygame
|
||||
import pymunk
|
||||
import torch
|
||||
import torchrl
|
||||
import tqdm
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.replay_buffers.samplers import Sampler
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage
|
||||
from torchrl.data.replay_buffers.writers import Writer
|
||||
|
||||
from lerobot.common.datasets.abstract import AbstractDataset
|
||||
from lerobot.common.datasets.utils import download_and_extract_zip
|
||||
from lerobot.common.datasets.utils import download_and_extract_zip, load_data_with_delta_timestamps
|
||||
from lerobot.common.envs.pusht.pusht_env import pymunk_to_shapely
|
||||
from lerobot.common.policies.diffusion.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer
|
||||
|
||||
@@ -83,37 +76,82 @@ def add_tee(
|
||||
return body
|
||||
|
||||
|
||||
class PushtDataset(AbstractDataset):
|
||||
class PushtDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
|
||||
Arguments
|
||||
----------
|
||||
delta_timestamps : dict[list[float]] | None, optional
|
||||
Loads data from frames with a shift in timestamps with a different strategy for each data key (e.g. state, action or image)
|
||||
If `None`, no shift is applied to current timestamp and the data from the current frame is loaded.
|
||||
"""
|
||||
|
||||
available_datasets = ["pusht"]
|
||||
fps = 10
|
||||
image_keys = ["observation.image"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id: str,
|
||||
version: str | None = "v1.2",
|
||||
batch_size: int | None = None,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
root: Path | None = None,
|
||||
pin_memory: bool = False,
|
||||
prefetch: int = None,
|
||||
sampler: Sampler | None = None,
|
||||
collate_fn: Callable | None = None,
|
||||
writer: Writer | None = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
transform: callable = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
dataset_id,
|
||||
version,
|
||||
batch_size,
|
||||
shuffle=shuffle,
|
||||
root=root,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch,
|
||||
sampler=sampler,
|
||||
collate_fn=collate_fn,
|
||||
writer=writer,
|
||||
transform=transform,
|
||||
)
|
||||
super().__init__()
|
||||
self.dataset_id = dataset_id
|
||||
self.version = version
|
||||
self.root = root
|
||||
self.transform = transform
|
||||
self.delta_timestamps = delta_timestamps
|
||||
|
||||
data_dir = self.root / f"{self.dataset_id}"
|
||||
if (data_dir / "data_dict.pth").exists() and (data_dir / "data_ids_per_episode.pth").exists():
|
||||
self.data_dict = torch.load(data_dir / "data_dict.pth")
|
||||
self.data_ids_per_episode = torch.load(data_dir / "data_ids_per_episode.pth")
|
||||
else:
|
||||
self._download_and_preproc_obsolete()
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(self.data_dict, data_dir / "data_dict.pth")
|
||||
torch.save(self.data_ids_per_episode, data_dir / "data_ids_per_episode.pth")
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
return len(self.data_dict["index"])
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
return len(self.data_ids_per_episode)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = {}
|
||||
|
||||
# get episode id and timestamp of the sampled frame
|
||||
current_ts = self.data_dict["timestamp"][idx].item()
|
||||
episode = self.data_dict["episode"][idx].item()
|
||||
|
||||
for key in self.data_dict:
|
||||
if self.delta_timestamps is not None and key in self.delta_timestamps:
|
||||
data, is_pad = load_data_with_delta_timestamps(
|
||||
self.data_dict,
|
||||
self.data_ids_per_episode,
|
||||
self.delta_timestamps,
|
||||
key,
|
||||
current_ts,
|
||||
episode,
|
||||
)
|
||||
item[key] = data
|
||||
item[f"{key}_is_pad"] = is_pad
|
||||
else:
|
||||
item[key] = self.data_dict[key][idx]
|
||||
|
||||
if self.transform is not None:
|
||||
item = self.transform(item)
|
||||
|
||||
return item
|
||||
|
||||
def _download_and_preproc_obsolete(self):
|
||||
assert self.root is not None
|
||||
@@ -147,8 +185,10 @@ class PushtDataset(AbstractDataset):
|
||||
states = torch.from_numpy(dataset_dict["state"])
|
||||
actions = torch.from_numpy(dataset_dict["action"])
|
||||
|
||||
self.data_ids_per_episode = {}
|
||||
ep_dicts = []
|
||||
|
||||
idx0 = 0
|
||||
idxtd = 0
|
||||
for episode_id in tqdm.tqdm(range(num_episodes)):
|
||||
idx1 = dataset_dict.meta["episode_ends"][episode_id]
|
||||
# to create test artifact
|
||||
@@ -194,30 +234,45 @@ class PushtDataset(AbstractDataset):
|
||||
# last step of demonstration is considered done
|
||||
done[-1] = True
|
||||
|
||||
ep_td = TensorDict(
|
||||
{
|
||||
("observation", "image"): image[:-1],
|
||||
("observation", "state"): agent_pos[:-1],
|
||||
"action": actions[idx0:idx1][:-1],
|
||||
"episode": episode_ids[idx0:idx1][:-1],
|
||||
"frame_id": torch.arange(0, num_frames - 1, 1),
|
||||
("next", "observation", "image"): image[1:],
|
||||
("next", "observation", "state"): agent_pos[1:],
|
||||
# TODO: verify that reward and done are aligned with image and agent_pos
|
||||
("next", "reward"): reward[1:],
|
||||
("next", "done"): done[1:],
|
||||
("next", "success"): success[1:],
|
||||
},
|
||||
batch_size=num_frames - 1,
|
||||
)
|
||||
ep_dict = {
|
||||
"observation.image": image,
|
||||
"observation.state": agent_pos,
|
||||
"action": actions[idx0:idx1],
|
||||
"episode": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
||||
"frame_id": torch.arange(0, num_frames, 1),
|
||||
"timestamp": torch.arange(0, num_frames, 1) / self.fps,
|
||||
# "next.observation.image": image[1:],
|
||||
# "next.observation.state": agent_pos[1:],
|
||||
# TODO(rcadene): verify that reward and done are aligned with image and agent_pos
|
||||
"next.reward": torch.cat([reward[1:], reward[[-1]]]),
|
||||
"next.done": torch.cat([done[1:], done[[-1]]]),
|
||||
"next.success": torch.cat([success[1:], success[[-1]]]),
|
||||
}
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
if episode_id == 0:
|
||||
# hack to initialize tensordict data structure to store episodes
|
||||
td_data = ep_td[0].expand(total_frames).memmap_like(self.root / f"{self.dataset_id}")
|
||||
|
||||
td_data[idxtd : idxtd + len(ep_td)] = ep_td
|
||||
assert isinstance(episode_id, int)
|
||||
self.data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1)
|
||||
assert len(self.data_ids_per_episode[episode_id]) == num_frames
|
||||
|
||||
idx0 = idx1
|
||||
idxtd = idxtd + len(ep_td)
|
||||
|
||||
return TensorStorage(td_data.lock_())
|
||||
self.data_dict = {}
|
||||
|
||||
keys = ep_dicts[0].keys()
|
||||
for key in keys:
|
||||
self.data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
||||
|
||||
self.data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
dataset = PushtDataset(
|
||||
"pusht",
|
||||
root=Path("data"),
|
||||
delta_timestamps={
|
||||
"observation.image": [0, -1, -0.2, -0.1],
|
||||
"observation.state": [0, -1, -0.2, -0.1],
|
||||
"action": [-0.1, 0, 1, 2, 3],
|
||||
},
|
||||
)
|
||||
dataset[10]
|
||||
|
||||
@@ -1,75 +1,104 @@
|
||||
import pickle
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
import torchrl
|
||||
import tqdm
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.replay_buffers.samplers import (
|
||||
Sampler,
|
||||
)
|
||||
from torchrl.data.replay_buffers.storages import TensorStorage
|
||||
from torchrl.data.replay_buffers.writers import Writer
|
||||
|
||||
from lerobot.common.datasets.abstract import AbstractDataset
|
||||
from lerobot.common.datasets.utils import load_data_with_delta_timestamps
|
||||
|
||||
|
||||
def download():
|
||||
raise NotImplementedError()
|
||||
def download(raw_dir):
|
||||
import gdown
|
||||
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
|
||||
download_path = "data.zip"
|
||||
gdown.download(url, download_path, quiet=False)
|
||||
zip_path = raw_dir / "data.zip"
|
||||
gdown.download(url, str(zip_path), quiet=False)
|
||||
print("Extracting...")
|
||||
with zipfile.ZipFile(download_path, "r") as zip_f:
|
||||
with zipfile.ZipFile(str(zip_path), "r") as zip_f:
|
||||
for member in zip_f.namelist():
|
||||
if member.startswith("data/xarm") and member.endswith(".pkl"):
|
||||
print(member)
|
||||
zip_f.extract(member=member)
|
||||
Path(download_path).unlink()
|
||||
zip_path.unlink()
|
||||
|
||||
|
||||
class SimxarmDataset(AbstractDataset):
|
||||
class SimxarmDataset(torch.utils.data.Dataset):
|
||||
available_datasets = [
|
||||
"xarm_lift_medium",
|
||||
]
|
||||
fps = 15
|
||||
image_keys = ["observation.image"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_id: str,
|
||||
version: str | None = "v1.1",
|
||||
batch_size: int | None = None,
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
root: Path | None = None,
|
||||
pin_memory: bool = False,
|
||||
prefetch: int = None,
|
||||
sampler: Sampler | None = None,
|
||||
collate_fn: Callable | None = None,
|
||||
writer: Writer | None = None,
|
||||
transform: "torchrl.envs.Transform" = None,
|
||||
transform: callable = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
dataset_id,
|
||||
version,
|
||||
batch_size,
|
||||
shuffle=shuffle,
|
||||
root=root,
|
||||
pin_memory=pin_memory,
|
||||
prefetch=prefetch,
|
||||
sampler=sampler,
|
||||
collate_fn=collate_fn,
|
||||
writer=writer,
|
||||
transform=transform,
|
||||
)
|
||||
super().__init__()
|
||||
self.dataset_id = dataset_id
|
||||
self.version = version
|
||||
self.root = root
|
||||
self.transform = transform
|
||||
self.delta_timestamps = delta_timestamps
|
||||
|
||||
data_dir = self.root / f"{self.dataset_id}"
|
||||
if (data_dir / "data_dict.pth").exists() and (data_dir / "data_ids_per_episode.pth").exists():
|
||||
self.data_dict = torch.load(data_dir / "data_dict.pth")
|
||||
self.data_ids_per_episode = torch.load(data_dir / "data_ids_per_episode.pth")
|
||||
else:
|
||||
self._download_and_preproc_obsolete()
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(self.data_dict, data_dir / "data_dict.pth")
|
||||
torch.save(self.data_ids_per_episode, data_dir / "data_ids_per_episode.pth")
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
return len(self.data_dict["index"])
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
return len(self.data_ids_per_episode)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = {}
|
||||
|
||||
# get episode id and timestamp of the sampled frame
|
||||
current_ts = self.data_dict["timestamp"][idx].item()
|
||||
episode = self.data_dict["episode"][idx].item()
|
||||
|
||||
for key in self.data_dict:
|
||||
if self.delta_timestamps is not None and key in self.delta_timestamps:
|
||||
data, is_pad = load_data_with_delta_timestamps(
|
||||
self.data_dict,
|
||||
self.data_ids_per_episode,
|
||||
self.delta_timestamps,
|
||||
key,
|
||||
current_ts,
|
||||
episode,
|
||||
)
|
||||
item[key] = data
|
||||
item[f"{key}_is_pad"] = is_pad
|
||||
else:
|
||||
item[key] = self.data_dict[key][idx]
|
||||
|
||||
if self.transform is not None:
|
||||
item = self.transform(item)
|
||||
|
||||
return item
|
||||
|
||||
def _download_and_preproc_obsolete(self):
|
||||
# assert self.root is not None
|
||||
# TODO(rcadene): finish download
|
||||
# download()
|
||||
assert self.root is not None
|
||||
raw_dir = self.root / f"{self.dataset_id}_raw"
|
||||
if not raw_dir.exists():
|
||||
download(raw_dir)
|
||||
|
||||
dataset_path = self.root / f"{self.dataset_id}" / "buffer.pkl"
|
||||
print(f"Using offline dataset '{dataset_path}'")
|
||||
@@ -78,6 +107,9 @@ class SimxarmDataset(AbstractDataset):
|
||||
|
||||
total_frames = dataset_dict["actions"].shape[0]
|
||||
|
||||
self.data_ids_per_episode = {}
|
||||
ep_dicts = []
|
||||
|
||||
idx0 = 0
|
||||
idx1 = 0
|
||||
episode_id = 0
|
||||
@@ -91,37 +123,38 @@ class SimxarmDataset(AbstractDataset):
|
||||
|
||||
image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1])
|
||||
state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
|
||||
next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1])
|
||||
next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1])
|
||||
action = torch.tensor(dataset_dict["actions"][idx0:idx1])
|
||||
# TODO(rcadene): concat the last "next_observations" to "observations"
|
||||
# next_image = torch.tensor(dataset_dict["next_observations"]["rgb"][idx0:idx1])
|
||||
# next_state = torch.tensor(dataset_dict["next_observations"]["state"][idx0:idx1])
|
||||
next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1])
|
||||
next_done = torch.tensor(dataset_dict["dones"][idx0:idx1])
|
||||
|
||||
episode = TensorDict(
|
||||
{
|
||||
("observation", "image"): image,
|
||||
("observation", "state"): state,
|
||||
"action": torch.tensor(dataset_dict["actions"][idx0:idx1]),
|
||||
"episode": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
||||
"frame_id": torch.arange(0, num_frames, 1),
|
||||
("next", "observation", "image"): next_image,
|
||||
("next", "observation", "state"): next_state,
|
||||
("next", "reward"): next_reward,
|
||||
("next", "done"): next_done,
|
||||
},
|
||||
batch_size=num_frames,
|
||||
)
|
||||
ep_dict = {
|
||||
"observation.image": image,
|
||||
"observation.state": state,
|
||||
"action": action,
|
||||
"episode": torch.tensor([episode_id] * num_frames, dtype=torch.int),
|
||||
"frame_id": torch.arange(0, num_frames, 1),
|
||||
"timestamp": torch.arange(0, num_frames, 1) / self.fps,
|
||||
# "next.observation.image": next_image,
|
||||
# "next.observation.state": next_state,
|
||||
"next.reward": next_reward,
|
||||
"next.done": next_done,
|
||||
}
|
||||
ep_dicts.append(ep_dict)
|
||||
|
||||
if episode_id == 0:
|
||||
# hack to initialize tensordict data structure to store episodes
|
||||
td_data = (
|
||||
episode[0]
|
||||
.expand(total_frames)
|
||||
.memmap_like(self.root / f"{self.dataset_id}" / "replay_buffer")
|
||||
)
|
||||
assert isinstance(episode_id, int)
|
||||
self.data_ids_per_episode[episode_id] = torch.arange(idx0, idx1, 1)
|
||||
assert len(self.data_ids_per_episode[episode_id]) == num_frames
|
||||
|
||||
td_data[idx0:idx1] = episode
|
||||
|
||||
episode_id += 1
|
||||
idx0 = idx1
|
||||
episode_id += 1
|
||||
|
||||
return TensorStorage(td_data.lock_())
|
||||
self.data_dict = {}
|
||||
|
||||
keys = ep_dicts[0].keys()
|
||||
for key in keys:
|
||||
self.data_dict[key] = torch.cat([x[key] for x in ep_dicts])
|
||||
|
||||
self.data_dict["index"] = torch.arange(0, total_frames, 1)
|
||||
|
||||
@@ -3,6 +3,7 @@ import zipfile
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
|
||||
@@ -28,3 +29,71 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def euclidean_distance_matrix(mat0, mat1):
|
||||
# Compute the square of the distance matrix
|
||||
sq0 = torch.sum(mat0**2, dim=1, keepdim=True)
|
||||
sq1 = torch.sum(mat1**2, dim=1, keepdim=True)
|
||||
distance_sq = sq0 + sq1.transpose(0, 1) - 2 * mat0 @ mat1.transpose(0, 1)
|
||||
|
||||
# Taking the square root to get the euclidean distance
|
||||
distance = torch.sqrt(torch.clamp(distance_sq, min=0))
|
||||
return distance
|
||||
|
||||
|
||||
def is_contiguously_true_or_false(bool_vector):
|
||||
assert bool_vector.ndim == 1
|
||||
assert bool_vector.dtype == torch.bool
|
||||
|
||||
# Compare each element with its neighbor to find changes
|
||||
changes = bool_vector[1:] != bool_vector[:-1]
|
||||
|
||||
# Count the number of changes
|
||||
num_changes = changes.sum().item()
|
||||
|
||||
# If there's more than one change, the list is not contiguous
|
||||
return num_changes <= 1
|
||||
|
||||
# examples = [
|
||||
# ([True, False, True, False, False, False], False),
|
||||
# ([True, True, True, False, False, False], True),
|
||||
# ([False, False, False, False, False, False], True)
|
||||
# ]
|
||||
# for bool_list, expected in examples:
|
||||
# result = is_contiguously_true_or_false(bool_list)
|
||||
|
||||
|
||||
def load_data_with_delta_timestamps(
|
||||
data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode
|
||||
):
|
||||
# get indices of the frames associated to the episode, and their timestamps
|
||||
ep_data_ids = data_ids_per_episode[episode]
|
||||
ep_timestamps = data_dict["timestamp"][ep_data_ids]
|
||||
|
||||
# get timestamps used as query to retrieve data of previous/future frames
|
||||
delta_ts = delta_timestamps[key]
|
||||
query_ts = current_ts + torch.tensor(delta_ts)
|
||||
|
||||
# compute distances between each query timestamp and all timestamps of all the frames belonging to the episode
|
||||
dist = euclidean_distance_matrix(query_ts[:, None], ep_timestamps[:, None])
|
||||
min_, argmin_ = dist.min(1)
|
||||
|
||||
# get the indices of the data that are closest to the query timestamps
|
||||
data_ids = ep_data_ids[argmin_]
|
||||
# closest_ts = ep_timestamps[argmin_]
|
||||
|
||||
# get the data
|
||||
data = data_dict[key][data_ids].clone()
|
||||
|
||||
# TODO(rcadene): synchronize timestamps + interpolation if needed
|
||||
|
||||
tol = 0.02
|
||||
is_pad = min_ > tol
|
||||
|
||||
assert is_contiguously_true_or_false(is_pad), (
|
||||
"One or several timestamps unexpectedly violate the tolerance."
|
||||
"This might be due to synchronization issues with timestamps during data collection."
|
||||
)
|
||||
|
||||
return data, is_pad
|
||||
|
||||
@@ -1,64 +1,40 @@
|
||||
from torchrl.envs import SerialEnv
|
||||
from torchrl.envs.transforms import Compose, StepCounter, Transform, TransformedEnv
|
||||
import gymnasium as gym
|
||||
|
||||
|
||||
def make_env(cfg, transform=None):
|
||||
def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv:
|
||||
"""
|
||||
Note: The returned environment is wrapped in a torchrl.SerialEnv with cfg.rollout_batch_size underlying
|
||||
environments. The env therefore returns batches.`
|
||||
Note: When `num_parallel_envs > 0`, this function returns a `SyncVectorEnv` which takes batched action as input and
|
||||
returns batched observation, reward, terminated, truncated of `num_parallel_envs` items.
|
||||
"""
|
||||
|
||||
kwargs = {
|
||||
"frame_skip": cfg.env.action_repeat,
|
||||
"from_pixels": cfg.env.from_pixels,
|
||||
"pixels_only": cfg.env.pixels_only,
|
||||
"image_size": cfg.env.image_size,
|
||||
"num_prev_obs": cfg.n_obs_steps - 1,
|
||||
}
|
||||
kwargs = {}
|
||||
|
||||
if cfg.env.name == "simxarm":
|
||||
from lerobot.common.envs.simxarm.env import SimxarmEnv
|
||||
|
||||
kwargs["task"] = cfg.env.task
|
||||
clsfunc = SimxarmEnv
|
||||
elif cfg.env.name == "pusht":
|
||||
from lerobot.common.envs.pusht.env import PushtEnv
|
||||
import gym_pusht # noqa
|
||||
|
||||
# assert kwargs["seed"] > 200, "Seed 0-200 are used for the demonstration dataset, so we don't want to seed the eval env with this range."
|
||||
|
||||
clsfunc = PushtEnv
|
||||
kwargs.update(
|
||||
{
|
||||
"obs_type": "pixels_agent_pos",
|
||||
"render_action": False,
|
||||
}
|
||||
)
|
||||
env_fn = lambda: gym.make( # noqa: E731
|
||||
"gym_pusht/PushTPixels-v0",
|
||||
render_mode="rgb_array",
|
||||
max_episode_steps=cfg.env.episode_length,
|
||||
**kwargs,
|
||||
)
|
||||
elif cfg.env.name == "aloha":
|
||||
from lerobot.common.envs.aloha.env import AlohaEnv
|
||||
|
||||
kwargs["task"] = cfg.env.task
|
||||
clsfunc = AlohaEnv
|
||||
else:
|
||||
raise ValueError(cfg.env.name)
|
||||
|
||||
def _make_env(seed):
|
||||
nonlocal kwargs
|
||||
kwargs["seed"] = seed
|
||||
env = clsfunc(**kwargs)
|
||||
|
||||
# limit rollout to max_steps
|
||||
env = TransformedEnv(env, StepCounter(max_steps=cfg.env.episode_length))
|
||||
|
||||
if transform is not None:
|
||||
# useful to add normalization
|
||||
if isinstance(transform, Compose):
|
||||
for tf in transform:
|
||||
env.append_transform(tf.clone())
|
||||
elif isinstance(transform, Transform):
|
||||
env.append_transform(transform.clone())
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
return env
|
||||
|
||||
return SerialEnv(
|
||||
cfg.rollout_batch_size,
|
||||
create_env_fn=_make_env,
|
||||
create_env_kwargs=[
|
||||
{"seed": env_seed} for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
|
||||
],
|
||||
)
|
||||
if num_parallel_envs == 0:
|
||||
# non-batched version of the env that returns an observation of shape (c)
|
||||
env = env_fn()
|
||||
else:
|
||||
# batched version of the env that returns an observation of shape (b, c)
|
||||
env = gym.vector.SyncVectorEnv([env_fn for _ in range(num_parallel_envs)])
|
||||
return env
|
||||
|
||||
@@ -55,7 +55,7 @@ class SimxarmEnv(AbstractEnv):
|
||||
if not _has_gym:
|
||||
raise ImportError("Cannot import gymnasium.")
|
||||
|
||||
import gymnasium
|
||||
import gymnasium as gym
|
||||
|
||||
from lerobot.common.envs.simxarm.simxarm import TASKS
|
||||
|
||||
@@ -65,7 +65,7 @@ class SimxarmEnv(AbstractEnv):
|
||||
self._env = TASKS[self.task]["env"]()
|
||||
|
||||
num_actions = len(TASKS[self.task]["action_space"])
|
||||
self._action_space = gymnasium.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,))
|
||||
self._action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,))
|
||||
self._action_padding = np.zeros((MAX_NUM_ACTIONS - num_actions), dtype=np.float32)
|
||||
if "w" not in TASKS[self.task]["action_space"]:
|
||||
self._action_padding[-1] = 1.0
|
||||
|
||||
@@ -1,18 +1,20 @@
|
||||
import copy
|
||||
import logging
|
||||
import time
|
||||
from collections import deque
|
||||
|
||||
import hydra
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from lerobot.common.policies.abstract import AbstractPolicy
|
||||
from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
|
||||
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
|
||||
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder, RgbEncoder
|
||||
from lerobot.common.policies.utils import populate_queues
|
||||
from lerobot.common.utils import get_safe_torch_device
|
||||
|
||||
|
||||
class DiffusionPolicy(AbstractPolicy):
|
||||
class DiffusionPolicy(nn.Module):
|
||||
name = "diffusion"
|
||||
|
||||
def __init__(
|
||||
@@ -38,8 +40,12 @@ class DiffusionPolicy(AbstractPolicy):
|
||||
# parameters passed to step
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(n_action_steps)
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
self.n_obs_steps = n_obs_steps
|
||||
self.n_action_steps = n_action_steps
|
||||
# queues are populated during rollout of the policy, they contain the n latest observations and actions
|
||||
self._queues = None
|
||||
|
||||
noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
|
||||
rgb_model_input_shape = copy.deepcopy(shape_meta.obs.image.shape)
|
||||
@@ -100,76 +106,58 @@ class DiffusionPolicy(AbstractPolicy):
|
||||
last_epoch=self.global_step - 1,
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Clear observation and action queues. Should be called on `env.reset()`
|
||||
"""
|
||||
self._queues = {
|
||||
"observation.image": deque(maxlen=self.n_obs_steps),
|
||||
"observation.state": deque(maxlen=self.n_obs_steps),
|
||||
"action": deque(maxlen=self.n_action_steps),
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def select_actions(self, observation, step_count):
|
||||
def select_action(self, batch, step):
|
||||
"""
|
||||
Note: this uses the ema model weights if self.training == False, otherwise the non-ema model weights.
|
||||
"""
|
||||
# TODO(rcadene): remove unused step_count
|
||||
del step_count
|
||||
# TODO(rcadene): remove unused step
|
||||
del step
|
||||
assert "observation.image" in batch
|
||||
assert "observation.state" in batch
|
||||
assert len(batch) == 2
|
||||
|
||||
obs_dict = {
|
||||
"image": observation["image"],
|
||||
"agent_pos": observation["state"],
|
||||
}
|
||||
if self.training:
|
||||
out = self.diffusion.predict_action(obs_dict)
|
||||
else:
|
||||
out = self.ema_diffusion.predict_action(obs_dict)
|
||||
action = out["action"]
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
if len(self._queues["action"]) == 0:
|
||||
# stack n latest observations from the queue
|
||||
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
|
||||
|
||||
obs_dict = {
|
||||
"image": batch["observation.image"],
|
||||
"agent_pos": batch["observation.state"],
|
||||
}
|
||||
if self.training:
|
||||
out = self.diffusion.predict_action(obs_dict)
|
||||
else:
|
||||
out = self.ema_diffusion.predict_action(obs_dict)
|
||||
self._queues["action"].extend(out["action"].transpose(0, 1))
|
||||
|
||||
action = self._queues["action"].popleft()
|
||||
return action
|
||||
|
||||
def update(self, replay_buffer, step):
|
||||
def forward(self, batch, step):
|
||||
start_time = time.time()
|
||||
|
||||
self.diffusion.train()
|
||||
|
||||
num_slices = self.cfg.batch_size
|
||||
batch_size = self.cfg.horizon * num_slices
|
||||
|
||||
assert batch_size % self.cfg.horizon == 0
|
||||
assert batch_size % num_slices == 0
|
||||
|
||||
def process_batch(batch, horizon, num_slices):
|
||||
# trajectory t = 64, horizon h = 16
|
||||
# (t h) ... -> t h ...
|
||||
batch = batch.reshape(num_slices, horizon) # .transpose(1, 0).contiguous()
|
||||
|
||||
# |-1|0|1|2|3|4|5|6|7|8|9|10|11|12|13|14| timestamps: 16
|
||||
# |o|o| observations: 2
|
||||
# | |a|a|a|a|a|a|a|a| actions executed: 8
|
||||
# |p|p|p|p|p|p|p|p|p|p|p| p| p| p| p| p| actions predicted: 16
|
||||
# note: we predict the action needed to go from t=-1 to t=0 similarly to an inverse kinematic model
|
||||
|
||||
image = batch["observation", "image"]
|
||||
state = batch["observation", "state"]
|
||||
action = batch["action"]
|
||||
assert image.shape[1] == horizon
|
||||
assert state.shape[1] == horizon
|
||||
assert action.shape[1] == horizon
|
||||
|
||||
if not (horizon == 16 and self.cfg.n_obs_steps == 2):
|
||||
raise NotImplementedError()
|
||||
|
||||
# keep first 2 observations of the slice corresponding to t=[-1,0]
|
||||
image = image[:, : self.cfg.n_obs_steps]
|
||||
state = state[:, : self.cfg.n_obs_steps]
|
||||
|
||||
out = {
|
||||
"obs": {
|
||||
"image": image.to(self.device, non_blocking=True),
|
||||
"agent_pos": state.to(self.device, non_blocking=True),
|
||||
},
|
||||
"action": action.to(self.device, non_blocking=True),
|
||||
}
|
||||
return out
|
||||
|
||||
batch = replay_buffer.sample(batch_size)
|
||||
batch = process_batch(batch, self.cfg.horizon, num_slices)
|
||||
|
||||
data_s = time.time() - start_time
|
||||
|
||||
loss = self.diffusion.compute_loss(batch)
|
||||
obs_dict = {
|
||||
"image": batch["observation.image"],
|
||||
"agent_pos": batch["observation.state"],
|
||||
}
|
||||
loss = self.diffusion.compute_loss(obs_dict)
|
||||
loss.backward()
|
||||
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
|
||||
@@ -1,53 +1,49 @@
|
||||
from typing import Sequence
|
||||
|
||||
import torch
|
||||
from tensordict import TensorDictBase
|
||||
from tensordict.nn import dispatch
|
||||
from tensordict.utils import NestedKey
|
||||
from torchrl.envs.transforms import ObservationTransform, Transform
|
||||
from torchvision.transforms.v2 import Compose, Transform
|
||||
|
||||
|
||||
class Prod(ObservationTransform):
|
||||
def apply_inverse_transform(item, transform):
|
||||
transforms = transform.transforms if isinstance(transform, Compose) else [transform]
|
||||
for tf in transforms[::-1]:
|
||||
if tf.invertible:
|
||||
item = tf.inverse_transform(item)
|
||||
else:
|
||||
raise ValueError(f"Inverse transform called on a non invertible transform ({tf}).")
|
||||
return item
|
||||
|
||||
|
||||
class Prod(Transform):
|
||||
invertible = True
|
||||
|
||||
def __init__(self, in_keys: Sequence[NestedKey], prod: float):
|
||||
def __init__(self, in_keys: list[str], prod: float):
|
||||
super().__init__()
|
||||
self.in_keys = in_keys
|
||||
self.prod = prod
|
||||
self.original_dtypes = {}
|
||||
|
||||
def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
|
||||
# _reset is called once when the environment reset to normalize the first observation
|
||||
tensordict_reset = self._call(tensordict_reset)
|
||||
return tensordict_reset
|
||||
|
||||
@dispatch(source="in_keys", dest="out_keys")
|
||||
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
|
||||
return self._call(tensordict)
|
||||
|
||||
def _call(self, td):
|
||||
def forward(self, item):
|
||||
for key in self.in_keys:
|
||||
if td.get(key, None) is None:
|
||||
if key not in item:
|
||||
continue
|
||||
self.original_dtypes[key] = td[key].dtype
|
||||
td[key] = td[key].type(torch.float32) * self.prod
|
||||
return td
|
||||
self.original_dtypes[key] = item[key].dtype
|
||||
item[key] = item[key].type(torch.float32) * self.prod
|
||||
return item
|
||||
|
||||
def _inv_call(self, td: TensorDictBase) -> TensorDictBase:
|
||||
def inverse_transform(self, item):
|
||||
for key in self.in_keys:
|
||||
if td.get(key, None) is None:
|
||||
if key not in item:
|
||||
continue
|
||||
td[key] = (td[key] / self.prod).type(self.original_dtypes[key])
|
||||
return td
|
||||
item[key] = (item[key] / self.prod).type(self.original_dtypes[key])
|
||||
return item
|
||||
|
||||
def transform_observation_spec(self, obs_spec):
|
||||
for key in self.in_keys:
|
||||
if obs_spec.get(key, None) is None:
|
||||
continue
|
||||
obs_spec[key].space.high = obs_spec[key].space.high.type(torch.float32) * self.prod
|
||||
obs_spec[key].space.low = obs_spec[key].space.low.type(torch.float32) * self.prod
|
||||
obs_spec[key].dtype = torch.float32
|
||||
return obs_spec
|
||||
# def transform_observation_spec(self, obs_spec):
|
||||
# for key in self.in_keys:
|
||||
# if obs_spec.get(key, None) is None:
|
||||
# continue
|
||||
# obs_spec[key].space.high = obs_spec[key].space.high.type(torch.float32) * self.prod
|
||||
# obs_spec[key].space.low = obs_spec[key].space.low.type(torch.float32) * self.prod
|
||||
# obs_spec[key].dtype = torch.float32
|
||||
# return obs_spec
|
||||
|
||||
|
||||
class NormalizeTransform(Transform):
|
||||
@@ -55,65 +51,50 @@ class NormalizeTransform(Transform):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stats: TensorDictBase,
|
||||
in_keys: Sequence[NestedKey] = None,
|
||||
out_keys: Sequence[NestedKey] | None = None,
|
||||
in_keys_inv: Sequence[NestedKey] | None = None,
|
||||
out_keys_inv: Sequence[NestedKey] | None = None,
|
||||
stats: dict,
|
||||
in_keys: list[str] = None,
|
||||
out_keys: list[str] | None = None,
|
||||
in_keys_inv: list[str] | None = None,
|
||||
out_keys_inv: list[str] | None = None,
|
||||
mode="mean_std",
|
||||
):
|
||||
if out_keys is None:
|
||||
out_keys = in_keys
|
||||
if in_keys_inv is None:
|
||||
in_keys_inv = out_keys
|
||||
if out_keys_inv is None:
|
||||
out_keys_inv = in_keys
|
||||
super().__init__(
|
||||
in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, out_keys_inv=out_keys_inv
|
||||
)
|
||||
super().__init__()
|
||||
self.in_keys = in_keys
|
||||
self.out_keys = in_keys if out_keys is None else out_keys
|
||||
self.in_keys_inv = self.out_keys if in_keys_inv is None else in_keys_inv
|
||||
self.out_keys_inv = self.in_keys if out_keys_inv is None else out_keys_inv
|
||||
self.stats = stats
|
||||
assert mode in ["mean_std", "min_max"]
|
||||
self.mode = mode
|
||||
|
||||
def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
|
||||
# _reset is called once when the environment reset to normalize the first observation
|
||||
tensordict_reset = self._call(tensordict_reset)
|
||||
return tensordict_reset
|
||||
|
||||
@dispatch(source="in_keys", dest="out_keys")
|
||||
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
|
||||
return self._call(tensordict)
|
||||
|
||||
def _call(self, td: TensorDictBase) -> TensorDictBase:
|
||||
def forward(self, item):
|
||||
for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False):
|
||||
# TODO(rcadene): don't know how to do `inkey not in td`
|
||||
if td.get(inkey, None) is None:
|
||||
if inkey not in item:
|
||||
continue
|
||||
if self.mode == "mean_std":
|
||||
mean = self.stats[inkey]["mean"]
|
||||
std = self.stats[inkey]["std"]
|
||||
td[outkey] = (td[inkey] - mean) / (std + 1e-8)
|
||||
item[outkey] = (item[inkey] - mean) / (std + 1e-8)
|
||||
else:
|
||||
min = self.stats[inkey]["min"]
|
||||
max = self.stats[inkey]["max"]
|
||||
# normalize to [0,1]
|
||||
td[outkey] = (td[inkey] - min) / (max - min)
|
||||
item[outkey] = (item[inkey] - min) / (max - min)
|
||||
# normalize to [-1, 1]
|
||||
td[outkey] = td[outkey] * 2 - 1
|
||||
return td
|
||||
item[outkey] = item[outkey] * 2 - 1
|
||||
return item
|
||||
|
||||
def _inv_call(self, td: TensorDictBase) -> TensorDictBase:
|
||||
def inverse_transform(self, item):
|
||||
for inkey, outkey in zip(self.in_keys_inv, self.out_keys_inv, strict=False):
|
||||
# TODO(rcadene): don't know how to do `inkey not in td`
|
||||
if td.get(inkey, None) is None:
|
||||
if inkey not in item:
|
||||
continue
|
||||
if self.mode == "mean_std":
|
||||
mean = self.stats[inkey]["mean"]
|
||||
std = self.stats[inkey]["std"]
|
||||
td[outkey] = td[inkey] * std + mean
|
||||
item[outkey] = item[inkey] * std + mean
|
||||
else:
|
||||
min = self.stats[inkey]["min"]
|
||||
max = self.stats[inkey]["max"]
|
||||
td[outkey] = (td[inkey] + 1) / 2
|
||||
td[outkey] = td[outkey] * (max - min) + min
|
||||
return td
|
||||
item[outkey] = (item[inkey] + 1) / 2
|
||||
item[outkey] = item[outkey] * (max - min) + min
|
||||
return item
|
||||
|
||||
Reference in New Issue
Block a user