Refactor train, eval_policy, logger, Add diffusion.yaml (WIP)
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder
|
||||
from diffusion_policy.policy.diffusion_unet_image_policy import DiffusionUnetImagePolicy
|
||||
|
||||
|
||||
|
||||
@@ -4,9 +4,29 @@ def make_policy(cfg):
|
||||
|
||||
policy = TDMPC(cfg.policy)
|
||||
elif cfg.policy.name == "diffusion":
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
from diffusion_policy.model.vision.model_getter import get_resnet
|
||||
from diffusion_policy.model.vision.multi_image_obs_encoder import (
|
||||
MultiImageObsEncoder,
|
||||
)
|
||||
|
||||
from lerobot.common.policies.diffusion import DiffusionPolicy
|
||||
|
||||
policy = DiffusionPolicy(cfg.policy)
|
||||
noise_scheduler = DDPMScheduler(**cfg.noise_scheduler)
|
||||
|
||||
rgb_model = get_resnet(**cfg.rgb_model)
|
||||
|
||||
obs_encoder = MultiImageObsEncoder(
|
||||
rgb_model=rgb_model,
|
||||
**cfg.obs_encoder,
|
||||
)
|
||||
|
||||
policy = DiffusionPolicy(
|
||||
noise_scheduler=noise_scheduler,
|
||||
obs_encoder=obs_encoder,
|
||||
n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
|
||||
**cfg.policy,
|
||||
)
|
||||
else:
|
||||
raise ValueError(cfg.policy.name)
|
||||
|
||||
|
||||
@@ -441,261 +441,6 @@ class Episode(object):
|
||||
self._idx += 1
|
||||
|
||||
|
||||
class ReplayBuffer:
    """
    Storage and sampling functionality.

    Transitions are kept in pre-allocated tensors on ``cfg.buffer_device``.
    Every tensor has ``capacity + cfg.horizon - 1`` rows so that a window of
    ``cfg.horizon`` consecutive rows can be read starting at any index below
    ``capacity``. Sampling is weighted by per-transition priorities
    (Prioritized Experience Replay).

    NOTE(review): ``_get_obs`` and ``sample`` move their outputs to CUDA
    unconditionally, so sampling requires a CUDA device even when the buffer
    itself lives on the CPU.
    """

    def __init__(self, cfg, dataset=None):
        """Allocate storage and optionally pre-fill it from an offline dataset.

        Args:
            cfg: Config object. This class reads ``action_dim``, ``img_size``,
                ``state_dim``, ``buffer_device``, ``max_buffer_size``,
                ``train_steps``, ``modality``, ``horizon`` here, and further
                fields (``batch_size``, ``frame_stack``, ``episode_length``,
                ``per_alpha``, ``per_beta``, ``data_first_percent``) in other
                methods.
            dataset: Optional dict of numpy arrays with the keys used by
                ``init_from_offline_dataset``.

        Raises:
            ValueError: If ``cfg.modality`` is not ``"pixels"``, ``"state"``
                or ``"all"``.
        """
        action_dim = cfg.action_dim
        obs_shape = {"rgb": (3, cfg.img_size, cfg.img_size), "state": (cfg.state_dim,)}

        self.cfg = cfg
        self.device = torch.device(cfg.buffer_device)
        print("Replay buffer device: ", self.device)

        if dataset is not None:
            # Make sure the whole offline dataset fits, even if it is larger
            # than the configured maximum.
            self.capacity = max(dataset["rewards"].shape[0], cfg.max_buffer_size)
        else:
            self.capacity = min(cfg.train_steps, cfg.max_buffer_size)

        self._init_obs_storage(obs_shape)

        n_slots = self.capacity + cfg.horizon - 1
        self._action = torch.zeros(
            (n_slots, action_dim), dtype=torch.float32, device=self.device
        )
        self._reward = torch.zeros((n_slots,), dtype=torch.float32, device=self.device)
        self._mask = torch.zeros((n_slots,), dtype=torch.float32, device=self.device)
        self._done = torch.zeros((n_slots,), dtype=torch.bool, device=self.device)
        self._priorities = torch.ones((n_slots,), dtype=torch.float32, device=self.device)
        self._eps = 1e-6
        self._full = False
        self.idx = 0
        if dataset is not None:
            self.init_from_offline_dataset(dataset)

        self._aug = aug(cfg)

    def _init_obs_storage(self, obs_shape):
        """Allocate ``_obs`` / ``_next_obs`` and set ``obs_shape`` for the modality.

        Args:
            obs_shape: Dict mapping ``"rgb"`` and ``"state"`` to their shapes.

        Raises:
            ValueError: If ``self.cfg.modality`` is not supported.
        """
        cfg = self.cfg
        n_slots = self.capacity + cfg.horizon - 1

        def single_frame_shape(key):
            # Images are stored one RGB frame at a time (3 channels plus the
            # trailing spatial dims); state vectors are stored as-is.
            shape = tuple(obs_shape[key])
            return shape if key == "state" else (3, *shape[-2:])

        def allocate(key):
            dtype = torch.float32 if key == "state" else torch.uint8
            return torch.zeros(
                (n_slots, *single_frame_shape(key)), dtype=dtype, device=self.device
            )

        if cfg.modality in {"pixels", "state"}:
            # ``obs_shape`` is a dict, so pick the entry matching the single
            # modality instead of slicing/unpacking the dict itself.
            key = "state" if cfg.modality == "state" else "rgb"
            # Note self.obs_shape always has single frame, which is different from cfg.obs_shape
            self.obs_shape = single_frame_shape(key)
            self._obs = allocate(key)
            self._next_obs = allocate(key)
        elif cfg.modality == "all":
            self.obs_shape = {}
            self._obs, self._next_obs = {}, {}
            for k in obs_shape:
                assert k in {"rgb", "state"}
                self.obs_shape[k] = single_frame_shape(k)
                self._obs[k] = allocate(k)
                self._next_obs[k] = self._obs[k].clone()
        else:
            raise ValueError(f"Unsupported modality: {cfg.modality!r}")

    def init_from_offline_dataset(self, dataset):
        """Initialize the replay buffer from an offline dataset.

        Copies the first ``cfg.data_first_percent`` fraction of the dataset
        into the buffer. Must be called on an empty buffer.

        Args:
            dataset: Dict with keys ``"observations"``, ``"next_observations"``,
                ``"actions"``, ``"rewards"``, ``"masks"`` and ``"dones"``.
                Leaves are passed to ``torch.from_numpy``; observation entries
                may be dicts mirroring ``self._obs``.
        """
        assert self.idx == 0 and not self._full
        n_transitions = int(len(dataset["rewards"]) * self.cfg.data_first_percent)

        def copy_data(dst, src, n):
            # Recurse through dict-structured observations; copy leaves.
            assert isinstance(dst, dict) == isinstance(src, dict)
            if isinstance(dst, dict):
                for k in dst:
                    copy_data(dst[k], src[k], n)
            else:
                dst[:n] = torch.from_numpy(src[:n])

        copy_data(self._obs, dataset["observations"], n_transitions)
        copy_data(self._next_obs, dataset["next_observations"], n_transitions)
        copy_data(self._action, dataset["actions"], n_transitions)
        copy_data(self._reward, dataset["rewards"], n_transitions)
        copy_data(self._mask, dataset["masks"], n_transitions)
        copy_data(self._done, dataset["dones"], n_transitions)
        self.idx = (self.idx + n_transitions) % self.capacity
        self._full = n_transitions >= self.capacity

    def __add__(self, episode: "Episode"):
        """Support ``buffer + episode`` / ``buffer += episode``; mutates and returns self."""
        self.add(episode)
        return self

    def add(self, episode: "Episode"):
        """Add an episode to the replay buffer.

        If the episode does not fit in the space left before ``capacity``, it
        is truncated (a warning is printed) and the write cursor wraps to 0.
        New transitions receive the current maximum priority.
        """
        if self.idx + len(episode) > self.capacity:
            print("Warning: episode got truncated")
        ep_len = min(len(episode), self.capacity - self.idx)
        idxs = slice(self.idx, self.idx + ep_len)
        assert self.idx + ep_len <= self.capacity

        modality = self.cfg.modality
        if modality in {"pixels", "state"}:
            if modality == "state":
                self._obs[idxs] = episode.obses[:ep_len]
                self._next_obs[idxs] = episode.obses[1 : ep_len + 1]
            else:
                # Keep only the last 3 channels (the most recent RGB frame).
                self._obs[idxs] = episode.obses[:ep_len, -3:]
                self._next_obs[idxs] = episode.obses[1 : ep_len + 1, -3:]
        elif modality == "all":
            for k in episode.obses:
                assert k in {"rgb", "state"}
                assert k in self._obs
                assert k in self._next_obs
                if k == "rgb":
                    self._obs[k][idxs] = episode.obses[k][:ep_len, -3:]
                    self._next_obs[k][idxs] = episode.obses[k][1 : ep_len + 1, -3:]
                else:
                    self._obs[k][idxs] = episode.obses[k][:ep_len]
                    self._next_obs[k][idxs] = episode.obses[k][1 : ep_len + 1]

        self._action[idxs] = episode.actions[:ep_len]
        self._reward[idxs] = episode.rewards[:ep_len]
        self._mask[idxs] = episode.masks[:ep_len]
        self._done[idxs] = episode.dones[:ep_len]
        self._done[self.idx + ep_len - 1] = True  # in case truncated

        if self._full:
            max_priority = self._priorities[: self.capacity].max().item()
        elif self.idx == 0:
            # Nothing stored yet: start from the default priority.
            max_priority = 1.0
        else:
            max_priority = self._priorities[: self.idx].max().item()
        self._priorities[idxs] = torch.full((ep_len,), max_priority, device=self.device)

        self.idx = (self.idx + ep_len) % self.capacity
        self._full = self._full or self.idx == 0

    def update_priorities(self, idxs, priorities):
        """Update priorities for Prioritized Experience Replay (PER).

        Args:
            idxs: Indices of the transitions to update.
            priorities: Tensor whose dim 1 is squeezed away before being
                written; ``self._eps`` is added to each value.
        """
        self._priorities[idxs] = priorities.squeeze(1).to(self.device) + self._eps

    def _get_obs(self, arr, idxs):
        """Retrieve observations by indices.

        Dicts are handled per key. Arrays with at most 2 dims are indexed
        directly. Higher-rank arrays are treated as images: the frame at each
        index is stacked with up to ``cfg.frame_stack - 1`` preceding frames
        along the channel axis, and the result is returned as float.
        Results are moved to CUDA.
        """
        if isinstance(arr, dict):
            return {k: self._get_obs(v, idxs) for k, v in arr.items()}
        if arr.ndim <= 2:  # if self.cfg.modality == 'state':
            return arr[idxs].cuda()

        obs = torch.empty(
            (self.cfg.batch_size, 3 * self.cfg.frame_stack, *arr.shape[-2:]),
            dtype=arr.dtype,
            device=torch.device("cuda"),
        )
        # The newest frame occupies the last 3 channels.
        obs[:, -3:] = arr[idxs].cuda()
        _idxs = idxs.clone()
        mask = torch.ones_like(_idxs, dtype=torch.bool)
        for i in range(1, self.cfg.frame_stack):
            # Once an index is a multiple of episode_length it stops moving
            # back, so the frame at that index is repeated for older slots.
            mask[_idxs % self.cfg.episode_length == 0] = False
            _idxs[mask] -= 1
            obs[:, -(i + 1) * 3 : -i * 3] = arr[_idxs].cuda()
        return obs.float()

    def sample(self):
        """Sample transitions from the replay buffer.

        Draws ``cfg.batch_size`` start indices with probability proportional
        to ``priority ** cfg.per_alpha`` and gathers ``cfg.horizon``
        consecutive rows from each start.

        Returns:
            Tuple ``(obs, next_obs, action, reward, mask, done, idxs,
            weights)``. ``next_obs``, ``action``, ``reward``, ``mask`` and
            ``done`` have a leading horizon axis; ``reward``, ``mask`` and
            ``done`` get a trailing singleton axis. ``weights`` is
            ``(total * probs[idxs]) ** -cfg.per_beta`` divided by its maximum.
        """
        n_valid = self.capacity if self._full else self.idx
        probs = self._priorities[:n_valid] ** self.cfg.per_alpha
        probs /= probs.sum()
        total = len(probs)
        idxs = torch.from_numpy(
            np.random.choice(
                total,
                self.cfg.batch_size,
                p=probs.cpu().numpy(),
                replace=not self._full,
            )
        ).to(self.device)
        weights = (total * probs[idxs]) ** (-self.cfg.per_beta)
        weights /= weights.max()

        # Row t holds the indices of step t for every sampled start index.
        idxs_in_horizon = torch.stack([idxs + t for t in range(self.cfg.horizon)])

        obs = self._aug(self._get_obs(self._obs, idxs))
        next_obs = [
            self._aug(self._get_obs(self._next_obs, _idxs)) for _idxs in idxs_in_horizon
        ]
        if isinstance(next_obs[0], dict):
            next_obs = {k: torch.stack([o[k] for o in next_obs]) for k in next_obs[0]}
        else:
            next_obs = torch.stack(next_obs)
        action = self._action[idxs_in_horizon]
        reward = self._reward[idxs_in_horizon]
        mask = self._mask[idxs_in_horizon]
        done = self._done[idxs_in_horizon]

        if not action.is_cuda:
            action, reward, mask, done, idxs, weights = (
                action.cuda(),
                reward.cuda(),
                mask.cuda(),
                done.cuda(),
                idxs.cuda(),
                weights.cuda(),
            )

        return (
            obs,
            next_obs,
            action,
            reward.unsqueeze(2),
            mask.unsqueeze(2),
            done.unsqueeze(2),
            idxs,
            weights,
        )

    def save(self, path):
        """Save the replay buffer to path.

        Pickles the filled part of the buffer as a dict of numpy arrays using
        the same keys ``init_from_offline_dataset`` reads, and returns it.
        """
        print(f"saving replay buffer to '{path}'...")
        sz = self.capacity if self._full else self.idx

        def to_numpy(store):
            # Observations may be a single tensor or a dict of tensors.
            if isinstance(store, dict):
                return {k: v[:sz].cpu().numpy() for k, v in store.items()}
            return store[:sz].cpu().numpy()

        dataset = {
            "observations": to_numpy(self._obs),
            "next_observations": to_numpy(self._next_obs),
            "actions": self._action[:sz].cpu().numpy(),
            "rewards": self._reward[:sz].cpu().numpy(),
            "dones": self._done[:sz].cpu().numpy(),
            "masks": self._mask[:sz].cpu().numpy(),
        }
        with open(path, "wb") as f:
            pickle.dump(dataset, f)
        return dataset
|
||||
|
||||
|
||||
def get_dataset_dict(cfg, env, return_reward_normalizer=False):
|
||||
"""Construct a dataset for env"""
|
||||
required_keys = [
|
||||
|
||||
Reference in New Issue
Block a user