Eval reproduced! Training running (but not yet reproduced)
@@ -1,5 +1,6 @@
 from copy import deepcopy
 
+import einops
 import numpy as np
 import torch
 import torch.nn as nn
@@ -90,7 +91,7 @@ class TDMPC(nn.Module):
         self.model_target = deepcopy(self.model)
         self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
         self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr)
-        self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
+        # self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
         self.model.eval()
         self.model_target.eval()
         self.batch_size = cfg.batch_size
@@ -308,9 +309,41 @@ class TDMPC(nn.Module):
         self.demo_batch_size = 0
 
         # Sample from interaction dataset
-        obs, next_obses, action, reward, mask, done, idxs, weights = (
-            replay_buffer.sample()
-        )
+        # To avoid having to mask:
+        # batch_size = (self.cfg.batch_size // self.cfg.horizon) * self.cfg.horizon
+        batch_size = self.cfg.horizon * self.cfg.batch_size
+        batch = replay_buffer.sample(batch_size)
+
+        # trajectory t = 256, horizon h = 5
+        # (t h) ... -> h t ...
+        batch = (
+            batch.reshape(self.cfg.batch_size, self.cfg.horizon)
+            .transpose(1, 0)
+            .contiguous()
+        )
+        batch = batch.to("cuda")
+
+        FIRST_FRAME = 0
+        obs = {
+            "rgb": batch["observation", "image"][FIRST_FRAME].float(),
+            "state": batch["observation", "state"][FIRST_FRAME],
+        }
+        action = batch["action"]
+        next_obses = {
+            "rgb": batch["next", "observation", "image"].float(),
+            "state": batch["next", "observation", "state"],
+        }
+        reward = batch["next", "reward"]
+        reward = einops.rearrange(reward, "h t -> h t 1")
+        # We don't use `batch["next", "done"]` since it only indicates the end of an
+        # episode, but not the end of the sampled trajectory within an episode.
+        # Neither does `batch["next", "terminated"]`.
+        done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
+        mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
+
+        idxs = batch["frame_id"][FIRST_FRAME]
+        weights = batch["_weight"][FIRST_FRAME, :, None]
 
         # Sample from demonstration dataset
         if self.demo_batch_size > 0:
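
Note: the reshuffle in this hunk assumes the buffer returns a flat, trajectory-major batch of t * h frames as a tensordict.TensorDict. A minimal standalone sketch (not from the commit, toy sizes) of the same "(t h) ... -> h t ..." move and the nested-key indexing:

import torch
from tensordict import TensorDict

t, h, state_dim = 4, 3, 2  # 4 trajectories, horizon 3 (made-up sizes)

# Flat batch of t*h frames, trajectory-major: all frames of trajectory 0 first.
flat = TensorDict(
    {"observation": {"state": torch.randn(t * h, state_dim)}},
    batch_size=[t * h],
)

# Same reshuffle as in the hunk: (t h) ... -> h t ...
batch = flat.reshape(t, h).transpose(1, 0).contiguous()
print(batch.batch_size)  # torch.Size([3, 4]), i.e. time-major

FIRST_FRAME = 0
# Tuple keys walk nested tensordicts; [FIRST_FRAME] takes step 0 of every trajectory.
print(batch["observation", "state"][FIRST_FRAME].shape)  # torch.Size([4, 2])

With a time-major layout, indexing with a step number yields that step across the whole batch, which is the shape the per-step training loop below iterates over.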
@@ -341,6 +374,21 @@ class TDMPC(nn.Module):
             idxs = torch.cat([idxs, demo_idxs])
             weights = torch.cat([weights, demo_weights])
 
         # Apply augmentations
         aug_tf = h.aug(self.cfg)
         obs = aug_tf(obs)
+
+        for k in next_obses:
+            next_obses[k] = einops.rearrange(next_obses[k], "h t ... -> (h t) ...")
+        next_obses = aug_tf(next_obses)
+        for k in next_obses:
+            next_obses[k] = einops.rearrange(
+                next_obses[k],
+                "(h t) ... -> h t ...",
+                h=self.cfg.horizon,
+                t=self.cfg.batch_size,
+            )
+
+        horizon = self.cfg.horizon
+        loss_mask = torch.ones_like(mask, device=self.device)
+        for t in range(1, horizon):
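
Note: the flatten/augment/unflatten pattern above is needed when the augmentation transform expects a single leading batch dimension. A self-contained sketch of the einops round trip, with a horizontal flip standing in for the real augmentation:

import einops
import torch

h, t = 5, 8                         # horizon, trajectories (toy sizes)
imgs = torch.rand(h, t, 3, 16, 16)  # time-major image stack

# Merge time and batch dims so a (batch, C, H, W) transform sees every frame.
flat = einops.rearrange(imgs, "h t ... -> (h t) ...")
assert flat.shape == (h * t, 3, 16, 16)

flat = torch.flip(flat, dims=[-1])  # stand-in for the actual augmentation

# Split the merged dim back out; h and t are given since (h t) alone is ambiguous.
restored = einops.rearrange(flat, "(h t) ... -> h t ...", h=h, t=t)
assert restored.shape == imgs.shape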
@@ -407,6 +455,7 @@ class TDMPC(nn.Module):
         weighted_loss = (total_loss.squeeze(1) * weights).mean()
+        weighted_loss.register_hook(lambda grad: grad * (1 / self.cfg.horizon))
         weighted_loss.backward()
 
         grad_norm = torch.nn.utils.clip_grad_norm_(
             self.model.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False
         )
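
Note: registering a hook on the loss rescales the gradient flowing backward out of that node by 1 / horizon without touching the loss value that gets logged; dividing the loss itself by horizon would scale gradients identically but also shrink the reported number. A toy demonstration of the mechanism:

import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)
loss = (3 * x).sum()

# The hook multiplies the incoming gradient (1.0 for a scalar loss) by 0.5.
loss.register_hook(lambda grad: grad * 0.5)
loss.backward()

print(x.grad)  # tensor([1.5000, 1.5000]) instead of tensor([3., 3.])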
@@ -415,13 +464,16 @@ class TDMPC(nn.Module):
         if self.cfg.per:
             # Update priorities
             priorities = priority_loss.clamp(max=1e4).detach()
-            replay_buffer.update_priorities(
-                idxs[: replay_buffer.cfg.batch_size],
-                priorities[: replay_buffer.cfg.batch_size],
+            # Normalize to [0, 1] to fit the torchrl specification
+            priorities /= 1e4
+            priorities = priorities.clamp(max=1.0)
+            replay_buffer.update_priority(
+                idxs[: self.cfg.batch_size],
+                priorities[: self.cfg.batch_size],
             )
             if self.demo_batch_size > 0:
-                demo_buffer.update_priorities(
-                    demo_idxs, priorities[replay_buffer.cfg.batch_size :]
+                demo_buffer.update_priority(
+                    demo_idxs, priorities[self.cfg.batch_size :]
                 )
 
         # Update policy + target network
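
Note: since priorities was already clamped to 1e4, dividing by 1e4 maps it into [0, 1] (assuming the priority loss is non-negative, as a TD-error-based priority typically is), so the second clamp is only a guard. The first self.cfg.batch_size entries belong to the interaction buffer and the remainder to the demo buffer. A standalone sketch of the normalization, with made-up loss values:

import torch

# Hypothetical non-negative per-sample priority losses.
priority_loss = torch.tensor([0.3, 12.0, 2.5e4])

priorities = priority_loss.clamp(max=1e4).detach()
priorities /= 1e4                       # now within [0, 1]
priorities = priorities.clamp(max=1.0)  # redundant guard, kept for safety
print(priorities)  # tensor([3.0000e-05, 1.2000e-03, 1.0000e+00])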