Refactor train, eval_policy, logger, Add diffusion.yaml (WIP)

This commit is contained in:
Cadene
2024-02-26 01:10:09 +00:00
parent 5a219fed6e
commit 21670dce90
12 changed files with 306 additions and 443 deletions

View File

@@ -4,6 +4,26 @@ from lerobot.common.datasets.pusht import PushtExperienceReplay
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
from rl.torchrl.data.replay_buffers.samplers import PrioritizedSliceSampler
# TODO(rcadene): implement
# dataset_d4rl = D4RLExperienceReplay(
# dataset_id="maze2d-umaze-v1",
# split_trajs=False,
# batch_size=1,
# sampler=SamplerWithoutReplacement(drop_last=False),
# prefetch=4,
# direct_download=True,
# )
# dataset_openx = OpenXExperienceReplay(
# "cmu_stretch",
# batch_size=1,
# num_slices=1,
# #download="force",
# streaming=False,
# root="data",
# )
def make_offline_buffer(cfg, sampler=None):

View File

@@ -10,10 +10,10 @@ from termcolor import colored
CONSOLE_FORMAT = [
("episode", "E", "int"),
("env_step", "S", "int"),
("step", "S", "int"),
("avg_sum_reward", "RS", "float"),
("avg_max_reward", "RM", "float"),
("pc_success", "S", "float"),
("pc_success", "SR", "float"),
("total_time", "T", "time"),
]
AGENT_METRICS = [
@@ -51,7 +51,9 @@ def print_run(cfg, reward=None):
kvs = [
("task", cfg.env.task),
("train steps", f"{int(cfg.train_steps * cfg.env.action_repeat):,}"),
("offline_steps", f"{cfg.offline_steps}"),
("online_steps", f"{cfg.online_steps}"),
("action_repeat", f"{cfg.env.action_repeat}"),
# ('observations', 'x'.join([str(s) for s in cfg.obs_shape])),
# ('actions', cfg.action_dim),
# ('experiment', cfg.exp_name),
@@ -78,54 +80,6 @@ def cfg_to_group(cfg, return_list=False):
return lst if return_list else "-".join(lst)
class VideoRecorder:
"""Utility class for logging evaluation videos."""
def __init__(self, root_dir, wandb, render_size=384, fps=15):
self.save_dir = (root_dir / "eval_video") if root_dir else None
self._wandb = wandb
self.render_size = render_size
self.fps = fps
self.frames = []
self.enabled = False
self.camera_id = 0
def init(self, env, enabled=True):
self.frames = []
self.enabled = self.save_dir and self._wandb and enabled
try:
env_name = env.unwrapped.spec.id
except:
env_name = ""
if "maze2d" in env_name:
self.camera_id = -1
elif "quadruped" in env_name:
self.camera_id = 2
self.record(env)
def record(self, env):
if self.enabled:
frame = env.render(
mode="rgb_array",
height=self.render_size,
width=self.render_size,
camera_id=self.camera_id,
)
self.frames.append(frame)
def save(self, step):
if self.enabled:
frames = np.stack(self.frames).transpose(0, 3, 1, 2)
self._wandb.log(
{
"eval_video": self._wandb.Video(
frames, fps=self.env.fps, format="mp4"
)
},
step=step,
)
class Logger(object):
"""Primary logger object. Logs either locally or using wandb."""
@@ -170,15 +124,6 @@ class Logger(object):
)
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
self._wandb = wandb
self._video = (
VideoRecorder(self._log_dir, self._wandb)
if self._wandb and cfg.save_video
else None
)
@property
def video(self):
return self._video
def save_model(self, agent, identifier):
if self._save_model:
@@ -214,12 +159,12 @@ class Logger(object):
def _format(self, key, value, ty):
if ty == "int":
return f'{colored(key + ":", "grey")} {int(value):,}'
return f'{colored(key + ":", "yellow")} {int(value):,}'
elif ty == "float":
return f'{colored(key + ":", "grey")} {value:.01f}'
return f'{colored(key + ":", "yellow")} {value:.01f}'
elif ty == "time":
value = str(datetime.timedelta(seconds=int(value)))
return f'{colored(key + ":", "grey")} {value}'
return f'{colored(key + ":", "yellow")} {value}'
else:
raise f"invalid log format type: {ty}"
@@ -234,10 +179,9 @@ class Logger(object):
assert category in {"train", "eval"}
if self._wandb is not None:
for k, v in d.items():
self._wandb.log({category + "/" + k: v}, step=d["env_step"])
self._wandb.log({category + "/" + k: v}, step=d["step"])
if category == "eval":
# keys = ['env_step', 'avg_reward']
keys = ["env_step", "avg_sum_reward", "avg_max_reward", "pc_success"]
keys = ["step", "avg_sum_reward", "avg_max_reward", "pc_success"]
self._eval.append(np.array([d[key] for key in keys]))
pd.DataFrame(np.array(self._eval)).to_csv(
self._log_dir / "eval.log", header=keys, index=None

View File

@@ -1,6 +1,8 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder
from diffusion_policy.policy.diffusion_unet_image_policy import DiffusionUnetImagePolicy

View File

@@ -4,9 +4,29 @@ def make_policy(cfg):
policy = TDMPC(cfg.policy)
elif cfg.policy.name == "diffusion":
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusion_policy.model.vision.model_getter import get_resnet
from diffusion_policy.model.vision.multi_image_obs_encoder import (
MultiImageObsEncoder,
)
from lerobot.common.policies.diffusion import DiffusionPolicy
policy = DiffusionPolicy(cfg.policy)
noise_scheduler = DDPMScheduler(**cfg.noise_scheduler)
rgb_model = get_resnet(**cfg.rgb_model)
obs_encoder = MultiImageObsEncoder(
rgb_model=rgb_model,
**cfg.obs_encoder,
)
policy = DiffusionPolicy(
noise_scheduler=noise_scheduler,
obs_encoder=obs_encoder,
n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
**cfg.policy,
)
else:
raise ValueError(cfg.policy.name)

View File

@@ -441,261 +441,6 @@ class Episode(object):
self._idx += 1
class ReplayBuffer:
"""
Storage and sampling functionality.
"""
def __init__(self, cfg, dataset=None):
action_dim = cfg.action_dim
obs_shape = {"rgb": (3, cfg.img_size, cfg.img_size), "state": (cfg.state_dim,)}
self.cfg = cfg
self.device = torch.device(cfg.buffer_device)
print("Replay buffer device: ", self.device)
if dataset is not None:
self.capacity = max(dataset["rewards"].shape[0], cfg.max_buffer_size)
else:
self.capacity = min(cfg.train_steps, cfg.max_buffer_size)
if cfg.modality in {"pixels", "state"}:
dtype = torch.float32 if cfg.modality == "state" else torch.uint8
# Note self.obs_shape always has single frame, which is different from cfg.obs_shape
self.obs_shape = (
obs_shape if cfg.modality == "state" else (3, *obs_shape[-2:])
)
self._obs = torch.zeros(
(self.capacity + cfg.horizon - 1, *self.obs_shape),
dtype=dtype,
device=self.device,
)
self._next_obs = torch.zeros(
(self.capacity + cfg.horizon - 1, *self.obs_shape),
dtype=dtype,
device=self.device,
)
elif cfg.modality == "all":
self.obs_shape = {}
self._obs, self._next_obs = {}, {}
for k, v in obs_shape.items():
assert k in {"rgb", "state"}
dtype = torch.float32 if k == "state" else torch.uint8
self.obs_shape[k] = v if k == "state" else (3, *v[-2:])
self._obs[k] = torch.zeros(
(self.capacity + cfg.horizon - 1, *self.obs_shape[k]),
dtype=dtype,
device=self.device,
)
self._next_obs[k] = self._obs[k].clone()
else:
raise ValueError
self._action = torch.zeros(
(self.capacity + cfg.horizon - 1, action_dim),
dtype=torch.float32,
device=self.device,
)
self._reward = torch.zeros(
(self.capacity + cfg.horizon - 1,), dtype=torch.float32, device=self.device
)
self._mask = torch.zeros(
(self.capacity + cfg.horizon - 1,), dtype=torch.float32, device=self.device
)
self._done = torch.zeros(
(self.capacity + cfg.horizon - 1,), dtype=torch.bool, device=self.device
)
self._priorities = torch.ones(
(self.capacity + cfg.horizon - 1,), dtype=torch.float32, device=self.device
)
self._eps = 1e-6
self._full = False
self.idx = 0
if dataset is not None:
self.init_from_offline_dataset(dataset)
self._aug = aug(cfg)
def init_from_offline_dataset(self, dataset):
"""Initialize the replay buffer from an offline dataset."""
assert self.idx == 0 and not self._full
n_transitions = int(len(dataset["rewards"]) * self.cfg.data_first_percent)
def copy_data(dst, src, n):
assert isinstance(dst, dict) == isinstance(src, dict)
if isinstance(dst, dict):
for k in dst:
copy_data(dst[k], src[k], n)
else:
dst[:n] = torch.from_numpy(src[:n])
copy_data(self._obs, dataset["observations"], n_transitions)
copy_data(self._next_obs, dataset["next_observations"], n_transitions)
copy_data(self._action, dataset["actions"], n_transitions)
copy_data(self._reward, dataset["rewards"], n_transitions)
copy_data(self._mask, dataset["masks"], n_transitions)
copy_data(self._done, dataset["dones"], n_transitions)
self.idx = (self.idx + n_transitions) % self.capacity
self._full = n_transitions >= self.capacity
def __add__(self, episode: Episode):
self.add(episode)
return self
def add(self, episode: Episode):
"""Add an episode to the replay buffer."""
if self.idx + len(episode) > self.capacity:
print("Warning: episode got truncated")
ep_len = min(len(episode), self.capacity - self.idx)
idxs = slice(self.idx, self.idx + ep_len)
assert self.idx + ep_len <= self.capacity
if self.cfg.modality in {"pixels", "state"}:
self._obs[idxs] = (
episode.obses[:ep_len]
if self.cfg.modality == "state"
else episode.obses[:ep_len, -3:]
)
self._next_obs[idxs] = (
episode.obses[1 : ep_len + 1]
if self.cfg.modality == "state"
else episode.obses[1 : ep_len + 1, -3:]
)
elif self.cfg.modality == "all":
for k, v in episode.obses.items():
assert k in {"rgb", "state"}
assert k in self._obs
assert k in self._next_obs
if k == "rgb":
self._obs[k][idxs] = episode.obses[k][:ep_len, -3:]
self._next_obs[k][idxs] = episode.obses[k][1 : ep_len + 1, -3:]
else:
self._obs[k][idxs] = episode.obses[k][:ep_len]
self._next_obs[k][idxs] = episode.obses[k][1 : ep_len + 1]
self._action[idxs] = episode.actions[:ep_len]
self._reward[idxs] = episode.rewards[:ep_len]
self._mask[idxs] = episode.masks[:ep_len]
self._done[idxs] = episode.dones[:ep_len]
self._done[self.idx + ep_len - 1] = True # in case truncated
if self._full:
max_priority = (
self._priorities[: self.capacity].max().to(self.device).item()
)
else:
max_priority = (
1.0
if self.idx == 0
else self._priorities[: self.idx].max().to(self.device).item()
)
new_priorities = torch.full((ep_len,), max_priority, device=self.device)
self._priorities[idxs] = new_priorities
self.idx = (self.idx + ep_len) % self.capacity
self._full = self._full or self.idx == 0
def update_priorities(self, idxs, priorities):
"""Update priorities for Prioritized Experience Replay (PER)"""
self._priorities[idxs] = priorities.squeeze(1).to(self.device) + self._eps
def _get_obs(self, arr, idxs):
"""Retrieve observations by indices"""
if isinstance(arr, dict):
return {k: self._get_obs(v, idxs) for k, v in arr.items()}
if arr.ndim <= 2: # if self.cfg.modality == 'state':
return arr[idxs].cuda()
obs = torch.empty(
(self.cfg.batch_size, 3 * self.cfg.frame_stack, *arr.shape[-2:]),
dtype=arr.dtype,
device=torch.device("cuda"),
)
obs[:, -3:] = arr[idxs].cuda()
_idxs = idxs.clone()
mask = torch.ones_like(_idxs, dtype=torch.bool)
for i in range(1, self.cfg.frame_stack):
mask[_idxs % self.cfg.episode_length == 0] = False
_idxs[mask] -= 1
obs[:, -(i + 1) * 3 : -i * 3] = arr[_idxs].cuda()
return obs.float()
def sample(self):
"""Sample transitions from the replay buffer."""
probs = (
self._priorities[: self.capacity]
if self._full
else self._priorities[: self.idx]
) ** self.cfg.per_alpha
probs /= probs.sum()
total = len(probs)
idxs = torch.from_numpy(
np.random.choice(
total,
self.cfg.batch_size,
p=probs.cpu().numpy(),
replace=not self._full,
)
).to(self.device)
weights = (total * probs[idxs]) ** (-self.cfg.per_beta)
weights /= weights.max()
idxs_in_horizon = torch.stack([idxs + t for t in range(self.cfg.horizon)])
obs = self._aug(self._get_obs(self._obs, idxs))
next_obs = [
self._aug(self._get_obs(self._next_obs, _idxs)) for _idxs in idxs_in_horizon
]
if isinstance(next_obs[0], dict):
next_obs = {k: torch.stack([o[k] for o in next_obs]) for k in next_obs[0]}
else:
next_obs = torch.stack(next_obs)
action = self._action[idxs_in_horizon]
reward = self._reward[idxs_in_horizon]
mask = self._mask[idxs_in_horizon]
done = self._done[idxs_in_horizon]
if not action.is_cuda:
action, reward, mask, done, idxs, weights = (
action.cuda(),
reward.cuda(),
mask.cuda(),
done.cuda(),
idxs.cuda(),
weights.cuda(),
)
return (
obs,
next_obs,
action,
reward.unsqueeze(2),
mask.unsqueeze(2),
done.unsqueeze(2),
idxs,
weights,
)
def save(self, path):
"""Save the replay buffer to path"""
print(f"saving replay buffer to '{path}'...")
sz = self.capacity if self._full else self.idx
dataset = {
"observations": (
{k: v[:sz].cpu().numpy() for k, v in self._obs.items()}
if isinstance(self._obs, dict)
else self._obs[:sz].cpu().numpy()
),
"next_observations": (
{k: v[:sz].cpu().numpy() for k, v in self._next_obs.items()}
if isinstance(self._next_obs, dict)
else self._next_obs[:sz].cpu().numpy()
),
"actions": self._action[:sz].cpu().numpy(),
"rewards": self._reward[:sz].cpu().numpy(),
"dones": self._done[:sz].cpu().numpy(),
"masks": self._mask[:sz].cpu().numpy(),
}
with open(path, "wb") as f:
pickle.dump(dataset, f)
return dataset
def get_dataset_dict(cfg, env, return_reward_normalizer=False):
"""Construct a dataset for env"""
required_keys = [