Eval reproduced! Train running (but not reproduced)

This commit is contained in:
Cadene
2024-02-10 15:46:24 +00:00
parent 937b2f8cba
commit 228c045674
14 changed files with 787 additions and 118 deletions

View File

@@ -1,5 +1,6 @@
from copy import deepcopy
import einops
import numpy as np
import torch
import torch.nn as nn
@@ -90,7 +91,7 @@ class TDMPC(nn.Module):
self.model_target = deepcopy(self.model)
self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr)
self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
# self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
self.model.eval()
self.model_target.eval()
self.batch_size = cfg.batch_size
@@ -308,9 +309,41 @@ class TDMPC(nn.Module):
self.demo_batch_size = 0
# Sample from interaction dataset
obs, next_obses, action, reward, mask, done, idxs, weights = (
replay_buffer.sample()
# to not have to mask
# batch_size = (self.cfg.batch_size // self.cfg.horizon) * self.cfg.horizon
batch_size = self.cfg.horizon * self.cfg.batch_size
batch = replay_buffer.sample(batch_size)
# trajectory t = 256, horizon h = 5
# (t h) ... -> h t ...
batch = (
batch.reshape(self.cfg.batch_size, self.cfg.horizon)
.transpose(1, 0)
.contiguous()
)
batch = batch.to("cuda")
FIRST_FRAME = 0
obs = {
"rgb": batch["observation", "image"][FIRST_FRAME].float(),
"state": batch["observation", "state"][FIRST_FRAME],
}
action = batch["action"]
next_obses = {
"rgb": batch["next", "observation", "image"].float(),
"state": batch["next", "observation", "state"],
}
reward = batch["next", "reward"]
reward = einops.rearrange(reward, "h t -> h t 1")
# We don't use `batch["next", "done"]` since it only indicates the end of an
# episode, but not the end of the trajectory of an episode.
# Neither does `batch["next", "terminated"]`.
done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
idxs = batch["frame_id"][FIRST_FRAME]
weights = batch["_weight"][FIRST_FRAME, :, None]
# Sample from demonstration dataset
if self.demo_batch_size > 0:
@@ -341,6 +374,21 @@ class TDMPC(nn.Module):
idxs = torch.cat([idxs, demo_idxs])
weights = torch.cat([weights, demo_weights])
# Apply augmentations
aug_tf = h.aug(self.cfg)
obs = aug_tf(obs)
for k in next_obses:
next_obses[k] = einops.rearrange(next_obses[k], "h t ... -> (h t) ...")
next_obses = aug_tf(next_obses)
for k in next_obses:
next_obses[k] = einops.rearrange(
next_obses[k],
"(h t) ... -> h t ...",
h=self.cfg.horizon,
t=self.cfg.batch_size,
)
horizon = self.cfg.horizon
loss_mask = torch.ones_like(mask, device=self.device)
for t in range(1, horizon):
@@ -407,6 +455,7 @@ class TDMPC(nn.Module):
weighted_loss = (total_loss.squeeze(1) * weights).mean()
weighted_loss.register_hook(lambda grad: grad * (1 / self.cfg.horizon))
weighted_loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
self.model.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False
)
@@ -415,13 +464,16 @@ class TDMPC(nn.Module):
if self.cfg.per:
# Update priorities
priorities = priority_loss.clamp(max=1e4).detach()
replay_buffer.update_priorities(
idxs[: replay_buffer.cfg.batch_size],
priorities[: replay_buffer.cfg.batch_size],
# normalize between [0,1] to fit torchrl specification
priorities /= 1e4
priorities = priorities.clamp(max=1.0)
replay_buffer.update_priority(
idxs[: self.cfg.batch_size],
priorities[: self.cfg.batch_size],
)
if self.demo_batch_size > 0:
demo_buffer.update_priorities(
demo_idxs, priorities[replay_buffer.cfg.batch_size :]
demo_buffer.update_priority(
demo_idxs, priorities[self.cfg.batch_size :]
)
# Update policy + target network