Eval reproduced! Training runs (but results not yet reproduced)

This commit is contained in:
Cadene
2024-02-10 15:46:24 +00:00
parent 937b2f8cba
commit 228c045674
14 changed files with 787 additions and 118 deletions

View File

View File

@@ -0,0 +1,190 @@
import os
import pickle
from pathlib import Path
from typing import Any, Callable, Dict, Tuple
import torch
import torchrl
import tqdm
from tensordict import TensorDict
from torchrl.data.datasets.utils import _get_root_dir
from torchrl.data.replay_buffers.replay_buffers import (
TensorDictPrioritizedReplayBuffer,
TensorDictReplayBuffer,
)
from torchrl.data.replay_buffers.samplers import (
Sampler,
SliceSampler,
SliceSamplerWithoutReplacement,
)
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
class SimxarmExperienceReplay(TensorDictReplayBuffer):
    """Replay buffer serving the offline Simxarm dataset through torchrl.

    On first use (or with ``download="force"``) the original pickled buffer is
    converted into a memory-mapped TensorDict stored under ``root/dataset_id``;
    afterwards the memmap is loaded directly. Sampling can optionally return
    fixed-length sub-trajectories via ``num_slices``/``slice_len``.
    """

    available_datasets = [
        "xarm_lift_medium",
    ]

    def __init__(
        self,
        dataset_id,
        batch_size: int = None,
        *,
        shuffle: bool = True,
        num_slices: int = None,
        slice_len: int = None,
        pad: float = None,
        replacement: bool = None,
        streaming: bool = False,
        root: Path = None,
        download: bool = False,
        sampler: Sampler = None,
        writer: Writer = None,
        collate_fn: Callable = None,
        pin_memory: bool = False,
        prefetch: int = None,
        transform: "torchrl.envs.Transform" = None,  # noqa-F821
        split_trajs: bool = False,
        strict_length: bool = True,
    ):
        self.download = download
        if streaming:
            raise NotImplementedError
        self.streaming = streaming
        self.dataset_id = dataset_id
        self.split_trajs = split_trajs
        self.shuffle = shuffle
        self.num_slices = num_slices
        self.slice_len = slice_len
        self.pad = pad
        self.strict_length = strict_length
        # num_slices and slice_len are alternative ways of specifying the
        # slice sampler configuration; only one may be given.
        if (self.num_slices is not None) and (self.slice_len is not None):
            raise ValueError("num_slices or slice_len can be not None, but not both.")
        if split_trajs:
            raise NotImplementedError
        if root is None:
            root = _get_root_dir("simxarm")
            os.makedirs(root, exist_ok=True)
        self.root = Path(root)
        # "force" re-converts even when a local memmap copy already exists.
        if self.download == "force" or (self.download and not self._is_downloaded()):
            storage = self._download_and_preproc()
        else:
            storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))

        if num_slices is not None or slice_len is not None:
            if sampler is not None:
                raise ValueError(
                    "`num_slices` and `slice_len` are exclusive with the `sampler` argument."
                )

            if replacement:
                if not self.shuffle:
                    raise RuntimeError(
                        "shuffle=False can only be used when replacement=False."
                    )
                sampler = SliceSampler(
                    num_slices=num_slices,
                    slice_len=slice_len,
                    strict_length=strict_length,
                )
            else:
                sampler = SliceSamplerWithoutReplacement(
                    num_slices=num_slices,
                    slice_len=slice_len,
                    strict_length=strict_length,
                    shuffle=self.shuffle,
                )

        if writer is None:
            writer = ImmutableDatasetWriter()
        if collate_fn is None:
            collate_fn = _collate_id

        super().__init__(
            storage=storage,
            sampler=sampler,
            writer=writer,
            collate_fn=collate_fn,
            pin_memory=pin_memory,
            prefetch=prefetch,
            batch_size=batch_size,
            transform=transform,
        )

    @property
    def data_path_root(self):
        # Streaming datasets have no local materialization on disk.
        if self.streaming:
            return None
        return self.root / self.dataset_id

    def _is_downloaded(self):
        """Return True when the memory-mapped dataset already exists on disk."""
        return os.path.exists(self.data_path_root)

    def _download_and_preproc(self):
        """Convert the pickled offline buffer into a memory-mapped TensorStorage.

        Reads ``data/<dataset_id>/buffer.pkl``, splits the flat arrays into
        episodes at each ``dones`` flag, and writes them frame-by-frame into a
        memmapped TensorDict under ``self.root / self.dataset_id``.
        """
        # download
        # TODO(rcadene)

        # load
        dataset_dir = Path("data") / self.dataset_id
        # BUG FIX: removed the no-op f-string prefix (no placeholders).
        dataset_path = dataset_dir / "buffer.pkl"
        print(f"Using offline dataset '{dataset_path}'")
        with open(dataset_path, "rb") as f:
            dataset_dict = pickle.load(f)
        total_frames = dataset_dict["actions"].shape[0]

        idx0 = 0
        idx1 = 0
        episode_id = 0
        for i in tqdm.tqdm(range(total_frames)):
            idx1 += 1
            # Accumulate frames until the episode's `done` flag, then flush.
            if not dataset_dict["dones"][i]:
                continue

            num_frames = idx1 - idx0

            image = torch.tensor(dataset_dict["observations"]["rgb"][idx0:idx1])
            state = torch.tensor(dataset_dict["observations"]["state"][idx0:idx1])
            next_image = torch.tensor(
                dataset_dict["next_observations"]["rgb"][idx0:idx1]
            )
            next_state = torch.tensor(
                dataset_dict["next_observations"]["state"][idx0:idx1]
            )
            next_reward = torch.tensor(dataset_dict["rewards"][idx0:idx1])
            next_done = torch.tensor(dataset_dict["dones"][idx0:idx1])

            episode = TensorDict(
                {
                    ("observation", "image"): image,
                    ("observation", "state"): state,
                    "action": torch.tensor(dataset_dict["actions"][idx0:idx1]),
                    "episode": torch.tensor([episode_id] * num_frames, dtype=torch.int),
                    "frame_id": torch.arange(0, num_frames, 1),
                    ("next", "observation", "image"): next_image,
                    ("next", "observation", "state"): next_state,
                    # NOTE(review): reward/done are nested under
                    # ("next", "observation", ...) here, but consumers appear
                    # to read ("next", "reward") — confirm the intended layout.
                    ("next", "observation", "reward"): next_reward,
                    ("next", "observation", "done"): next_done,
                },
                batch_size=num_frames,
            )

            if episode_id == 0:
                # hack to initialize tensordict data structure to store episodes
                td_data = (
                    episode[0]
                    .expand(total_frames)
                    .memmap_like(self.root / self.dataset_id)
                )

            td_data[idx0:idx1] = episode

            episode_id += 1
            idx0 = idx1

        return TensorStorage(td_data.lock_())

View File

@@ -7,6 +7,7 @@ def make_env(cfg):
assert cfg.env == "simxarm"
env = SimxarmEnv(
task=cfg.task,
frame_skip=cfg.action_repeat,
from_pixels=cfg.from_pixels,
pixels_only=cfg.pixels_only,
image_size=cfg.image_size,

View File

@@ -24,6 +24,7 @@ class SimxarmEnv(EnvBase):
def __init__(
self,
task,
frame_skip: int = 1,
from_pixels: bool = False,
pixels_only: bool = False,
image_size=None,
@@ -32,6 +33,7 @@ class SimxarmEnv(EnvBase):
):
super().__init__(device=device, batch_size=[])
self.task = task
self.frame_skip = frame_skip
self.from_pixels = from_pixels
self.pixels_only = pixels_only
self.image_size = image_size
@@ -115,12 +117,15 @@ class SimxarmEnv(EnvBase):
# step expects shape=(4,) so we pad if necessary
action = np.concatenate([action, self._action_padding])
# TODO(rcadene): add info["is_success"] and info["success"] ?
raw_obs, reward, done, info = self._env.step(action)
sum_reward = 0
for t in range(self.frame_skip):
raw_obs, reward, done, info = self._env.step(action)
sum_reward += reward
td = TensorDict(
{
"observation": self._format_raw_obs(raw_obs),
"reward": torch.tensor([reward], dtype=torch.float32),
"reward": torch.tensor([sum_reward], dtype=torch.float32),
"done": torch.tensor([done], dtype=torch.bool),
"success": torch.tensor([info["success"]], dtype=torch.bool),
},

243
lerobot/common/logger.py Normal file
View File

@@ -0,0 +1,243 @@
import datetime
import os
import re
from pathlib import Path
import numpy as np
import pandas as pd
from omegaconf import OmegaConf
from termcolor import colored
# Columns printed to the console by Logger._print:
# (metric key, display label, format type) — types handled in Logger._format.
CONSOLE_FORMAT = [
    ("episode", "E", "int"),
    ("env_step", "S", "int"),
    ("avg_reward", "R", "float"),
    # NOTE(review): label "R" duplicates avg_reward's label — possibly meant
    # to be a distinct letter; confirm the intended console output.
    ("pc_success", "R", "float"),
    ("total_time", "T", "time"),
]
# Per-update agent metrics that may be logged during training.
AGENT_METRICS = [
    "consistency_loss",
    "reward_loss",
    "value_loss",
    "total_loss",
    "weighted_loss",
    "pi_loss",
    "grad_norm",
]
def make_dir(dir_path):
    """Create *dir_path* (including parents) if missing and return it unchanged.

    Creation errors are deliberately swallowed: directory setup is
    best-effort and must never abort a run.
    """
    try:
        dir_path.mkdir(parents=True, exist_ok=True)
    except OSError:
        # e.g. permission problems — caller still gets the path back.
        pass
    return dir_path
def print_run(cfg, reward=None):
    """Pretty-print a summary table of run information.

    Called at the start of training; when *reward* is provided (e.g. at the
    end of a run) an "episode reward" row is appended to the table.
    """
    prefix, color, attrs = " ", "green", ["bold"]

    def limstr(s, maxlen=32):
        # Truncate long values so the table stays narrow.
        return str(s[:maxlen]) + "..." if len(str(s)) > maxlen else s

    def pprint(k, v):
        print(
            prefix + colored(f'{k.capitalize() + ":":<16}', color, attrs=attrs),
            limstr(v),
        )

    kvs = [
        ("task", cfg.task),
        ("train steps", f"{int(cfg.train_steps * cfg.action_repeat):,}"),
    ]
    if reward is not None:
        bold_reward = colored(str(int(reward)), "white", attrs=["bold"])
        kvs.append(("episode reward", bold_reward))

    # Divider width follows the widest (truncated) value plus the label column.
    w = max(len(limstr(str(v))) for _, v in kvs) + 21
    div = "-" * w
    print(div)
    for k, v in kvs:
        pprint(k, v)
    print(div)
def cfg_to_group(cfg, return_list=False):
    """Build a wandb-safe group identifier from the run configuration.

    The group is ``<task>-<modality>-<sanitized exp_name>``, where each run of
    characters outside [0-9a-zA-Z] in ``exp_name`` collapses to one dash. With
    ``return_list=True`` the three components are returned as a list instead.
    """
    safe_exp_name = re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)
    parts = [cfg.task, cfg.modality, safe_exp_name]
    if return_list:
        return parts
    return "-".join(parts)
class VideoRecorder:
    """Utility class for logging evaluation videos to wandb.

    Frames are collected with :meth:`record` during an evaluation episode and
    uploaded as a single mp4 via :meth:`save`. Recording is a no-op unless a
    save directory and a wandb handle exist and recording was enabled.
    """

    def __init__(self, root_dir, wandb, render_size=384, fps=15):
        # Videos are only saved when a root directory was supplied.
        self.save_dir = (root_dir / "eval_video") if root_dir else None
        self._wandb = wandb
        self.render_size = render_size
        self.fps = fps
        self.frames = []
        self.enabled = False
        self.camera_id = 0

    def init(self, env, enabled=True):
        """Reset the frame buffer and pick a camera for this environment."""
        self.frames = []
        self.enabled = self.save_dir and self._wandb and enabled
        try:
            env_name = env.unwrapped.spec.id
        # BUG FIX: bare `except:` also swallowed KeyboardInterrupt/SystemExit;
        # only the expected "not a gym-registered env" failure is caught now.
        except AttributeError:
            env_name = ""
        if "maze2d" in env_name:
            self.camera_id = -1
        elif "quadruped" in env_name:
            self.camera_id = 2
        self.record(env)

    def record(self, env):
        """Append one rendered frame if recording is enabled."""
        if self.enabled:
            frame = env.render(
                mode="rgb_array",
                height=self.render_size,
                width=self.render_size,
                camera_id=self.camera_id,
            )
            self.frames.append(frame)

    def save(self, step):
        """Upload the collected frames to wandb as an mp4 logged at *step*."""
        if self.enabled:
            # wandb.Video expects (time, channel, height, width).
            frames = np.stack(self.frames).transpose(0, 3, 1, 2)
            self._wandb.log(
                {"eval_video": self._wandb.Video(frames, fps=self.fps, format="mp4")},
                step=step,
            )
class Logger(object):
    """Primary logger object. Logs either locally or using wandb.

    Console rows follow ``CONSOLE_FORMAT``; eval metrics are additionally
    accumulated in ``self._eval`` and written to ``eval.log`` as CSV.
    """

    def __init__(self, log_dir, cfg):
        self._log_dir = make_dir(Path(log_dir))
        self._model_dir = make_dir(self._log_dir / "models")
        self._buffer_dir = make_dir(self._log_dir / "buffers")
        self._save_model = cfg.save_model
        self._save_buffer = cfg.save_buffer
        self._group = cfg_to_group(cfg)
        self._seed = cfg.seed
        self._cfg = cfg
        self._eval = []
        print_run(cfg)
        project, entity = cfg.get("wandb_project", "none"), cfg.get(
            "wandb_entity", "none"
        )
        # wandb is only used when explicitly enabled AND fully configured.
        run_offline = (
            not cfg.get("use_wandb", False) or project == "none" or entity == "none"
        )
        if run_offline:
            print(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
            self._wandb = None
        else:
            try:
                os.environ["WANDB_SILENT"] = "true"
                import wandb

                wandb.init(
                    project=project,
                    entity=entity,
                    name=str(cfg.seed),
                    notes=cfg.notes,
                    group=self._group,
                    tags=cfg_to_group(cfg, return_list=True) + [f"seed:{cfg.seed}"],
                    dir=self._log_dir,
                    config=OmegaConf.to_container(cfg, resolve=True),
                )
                print(
                    colored("Logs will be synced with wandb.", "blue", attrs=["bold"])
                )
                self._wandb = wandb
            # BUG FIX: was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit; fall back to local logging only on
            # ordinary failures (missing package, bad credentials, network).
            except Exception:
                print(
                    colored(
                        "Warning: failed to init wandb. Make sure `wandb_entity` is set to your username in `config.yaml`. Logs will be saved locally.",
                        "yellow",
                        attrs=["bold"],
                    )
                )
                self._wandb = None
        self._video = (
            VideoRecorder(log_dir, self._wandb)
            if self._wandb and cfg.save_video
            else None
        )

    @property
    def video(self):
        # VideoRecorder instance, or None when video logging is disabled.
        return self._video

    def save_model(self, agent, identifier):
        """Save the agent checkpoint; also upload it as a wandb artifact."""
        if self._save_model:
            fp = self._model_dir / f"{str(identifier)}.pt"
            agent.save(fp)
            if self._wandb:
                artifact = self._wandb.Artifact(
                    self._group + "-" + str(self._seed) + "-" + str(identifier),
                    type="model",
                )
                artifact.add_file(fp)
                self._wandb.log_artifact(artifact)

    def save_buffer(self, buffer, identifier):
        """Save the replay buffer; also upload it as a wandb artifact."""
        fp = self._buffer_dir / f"{str(identifier)}.pkl"
        buffer.save(fp)
        if self._wandb:
            artifact = self._wandb.Artifact(
                self._group + "-" + str(self._seed) + "-" + str(identifier),
                type="buffer",
            )
            artifact.add_file(fp)
            self._wandb.log_artifact(artifact)

    def finish(self, agent, buffer):
        """Persist final artifacts, close wandb, and reprint the run summary."""
        if self._save_model:
            self.save_model(agent, identifier="final")
        if self._save_buffer:
            self.save_buffer(buffer, identifier="buffer")
        if self._wandb:
            self._wandb.finish()
        # NOTE(review): assumes at least one eval happened — self._eval[-1]
        # raises IndexError otherwise; confirm callers guarantee this.
        print_run(self._cfg, self._eval[-1][-1])

    def _format(self, key, value, ty):
        """Render one console cell according to its format type."""
        if ty == "int":
            return f'{colored(key + ":", "grey")} {int(value):,}'
        elif ty == "float":
            return f'{colored(key + ":", "grey")} {value:.01f}'
        elif ty == "time":
            value = str(datetime.timedelta(seconds=int(value)))
            return f'{colored(key + ":", "grey")} {value}'
        else:
            # BUG FIX: `raise "..."` raises a TypeError in Python 3 (only
            # BaseException subclasses may be raised); use a real exception.
            raise ValueError(f"invalid log format type: {ty}")

    def _print(self, d, category):
        """Print one console row of the metrics dict *d*."""
        category = colored(category, "blue" if category == "train" else "green")
        pieces = [f" {category:<14}"]
        for k, disp_k, ty in CONSOLE_FORMAT:
            # Missing metrics default to 0 so the row shape stays fixed.
            pieces.append(f"{self._format(disp_k, d.get(k, 0), ty):<26}")
        print(" ".join(pieces))

    def log(self, d, category="train"):
        """Log a metrics dict to wandb (if enabled), eval.log, and the console."""
        assert category in {"train", "eval"}
        if self._wandb is not None:
            for k, v in d.items():
                self._wandb.log({category + "/" + k: v}, step=d["env_step"])
        if category == "eval":
            # keys = ['env_step', 'avg_reward']
            keys = ["env_step", "avg_reward", "pc_success"]
            self._eval.append(np.array([d[key] for key in keys]))
            pd.DataFrame(np.array(self._eval)).to_csv(
                self._log_dir / "eval.log", header=keys, index=None
            )
        self._print(d, category)

View File

@@ -1,5 +1,6 @@
from copy import deepcopy
import einops
import numpy as np
import torch
import torch.nn as nn
@@ -90,7 +91,7 @@ class TDMPC(nn.Module):
self.model_target = deepcopy(self.model)
self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr)
self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
# self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
self.model.eval()
self.model_target.eval()
self.batch_size = cfg.batch_size
@@ -308,9 +309,41 @@ class TDMPC(nn.Module):
self.demo_batch_size = 0
# Sample from interaction dataset
obs, next_obses, action, reward, mask, done, idxs, weights = (
replay_buffer.sample()
# to not have to mask
# batch_size = (self.cfg.batch_size // self.cfg.horizon) * self.cfg.horizon
batch_size = self.cfg.horizon * self.cfg.batch_size
batch = replay_buffer.sample(batch_size)
# trajectory t = 256, horizon h = 5
# (t h) ... -> h t ...
batch = (
batch.reshape(self.cfg.batch_size, self.cfg.horizon)
.transpose(1, 0)
.contiguous()
)
batch = batch.to("cuda")
FIRST_FRAME = 0
obs = {
"rgb": batch["observation", "image"][FIRST_FRAME].float(),
"state": batch["observation", "state"][FIRST_FRAME],
}
action = batch["action"]
next_obses = {
"rgb": batch["next", "observation", "image"].float(),
"state": batch["next", "observation", "state"],
}
reward = batch["next", "reward"]
reward = einops.rearrange(reward, "h t -> h t 1")
# We dont use `batch["next", "done"]` since it only indicates the end of an
# episode, but not the end of the trajectory of an episode.
# Neither does `batch["next", "terminated"]`
done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
idxs = batch["frame_id"][FIRST_FRAME]
weights = batch["_weight"][FIRST_FRAME, :, None]
# Sample from demonstration dataset
if self.demo_batch_size > 0:
@@ -341,6 +374,21 @@ class TDMPC(nn.Module):
idxs = torch.cat([idxs, demo_idxs])
weights = torch.cat([weights, demo_weights])
# Apply augmentations
aug_tf = h.aug(self.cfg)
obs = aug_tf(obs)
for k in next_obses:
next_obses[k] = einops.rearrange(next_obses[k], "h t ... -> (h t) ...")
next_obses = aug_tf(next_obses)
for k in next_obses:
next_obses[k] = einops.rearrange(
next_obses[k],
"(h t) ... -> h t ...",
h=self.cfg.horizon,
t=self.cfg.batch_size,
)
horizon = self.cfg.horizon
loss_mask = torch.ones_like(mask, device=self.device)
for t in range(1, horizon):
@@ -407,6 +455,7 @@ class TDMPC(nn.Module):
weighted_loss = (total_loss.squeeze(1) * weights).mean()
weighted_loss.register_hook(lambda grad: grad * (1 / self.cfg.horizon))
weighted_loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
self.model.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False
)
@@ -415,13 +464,16 @@ class TDMPC(nn.Module):
if self.cfg.per:
# Update priorities
priorities = priority_loss.clamp(max=1e4).detach()
replay_buffer.update_priorities(
idxs[: replay_buffer.cfg.batch_size],
priorities[: replay_buffer.cfg.batch_size],
# normalize between [0,1] to fit torchrl specification
priorities /= 1e4
priorities = priorities.clamp(max=1.0)
replay_buffer.update_priority(
idxs[: self.cfg.batch_size],
priorities[: self.cfg.batch_size],
)
if self.demo_batch_size > 0:
demo_buffer.update_priorities(
demo_idxs, priorities[replay_buffer.cfg.batch_size :]
demo_buffer.update_priority(
demo_idxs, priorities[self.cfg.batch_size :]
)
# Update policy + target network

View File

@@ -306,13 +306,21 @@ class RandomShiftsAug(nn.Module):
x = F.pad(x, padding, "replicate")
eps = 1.0 / (h + 2 * self.pad)
arange = torch.linspace(
-1.0 + eps, 1.0 - eps, h + 2 * self.pad, device=x.device, dtype=x.dtype
-1.0 + eps,
1.0 - eps,
h + 2 * self.pad,
device=x.device,
dtype=torch.float32,
)[:h]
arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2)
base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2)
base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1)
shift = torch.randint(
0, 2 * self.pad + 1, size=(n, 1, 1, 2), device=x.device, dtype=x.dtype
0,
2 * self.pad + 1,
size=(n, 1, 1, 2),
device=x.device,
dtype=torch.float32,
)
shift *= 2.0 / (h + 2 * self.pad)
grid = base_grid + shift