|
|
|
|
@@ -110,7 +110,6 @@ class TDMPCPolicy(nn.Module):
|
|
|
|
|
# self.bc_optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
|
|
|
|
|
self.model.eval()
|
|
|
|
|
self.model_target.eval()
|
|
|
|
|
self.batch_size = cfg.batch_size
|
|
|
|
|
|
|
|
|
|
self.register_buffer("step", torch.zeros(1))
|
|
|
|
|
|
|
|
|
|
@@ -325,7 +324,7 @@ class TDMPCPolicy(nn.Module):
|
|
|
|
|
def _td_target(self, next_z, reward, mask):
    """Compute the TD-target from a reward and the observation at the following time step.

    Args:
        next_z: Latent state for the following time step; passed to ``self.model.V``.
        reward: Reward tensor the discounted next-state value is added to.
        mask: Multiplicative mask applied to the discounted next-state value.

    Returns:
        ``reward + self.cfg.discount * mask * V(next_z)``, where dimension 2 of the
        value estimate is squeezed before the arithmetic.
    """
    # Squeeze dimension 2 of the value estimate exactly once, up front.
    # Presumably V returns a trailing singleton channel, e.g. (time, batch, 1), while
    # reward/mask are (time, batch) -- TODO confirm against the model definition.
    # The earlier un-squeezed computation was overwritten immediately and could
    # broadcast (or fail) against that extra dimension, so it is not kept.
    next_v = self.model.V(next_z).squeeze(2)
    return reward + self.cfg.discount * mask * next_v
|
|
|
|
|
|
|
|
|
|
def forward(self, batch, step):
|
|
|
|
|
@@ -420,6 +419,8 @@ class TDMPCPolicy(nn.Module):
|
|
|
|
|
# idxs = torch.cat([idxs, demo_idxs])
|
|
|
|
|
# weights = torch.cat([weights, demo_weights])
|
|
|
|
|
|
|
|
|
|
batch_size = batch["index"].shape[0]
|
|
|
|
|
|
|
|
|
|
# TODO(rcadene): convert tdmpc with (batch size, time/horizon, channels)
|
|
|
|
|
# instead of currently (time/horizon, batch size, channels) which is not the pytorch convention
|
|
|
|
|
# batch size b = 256, time/horizon t = 5
|
|
|
|
|
@@ -433,7 +434,7 @@ class TDMPCPolicy(nn.Module):
|
|
|
|
|
# idxs = batch["index"] # TODO(rcadene): use idxs to update sampling weights
|
|
|
|
|
done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
|
|
|
|
|
mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
|
|
|
|
|
weights = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
|
|
|
|
|
weights = torch.ones(batch_size, dtype=torch.bool, device=reward.device)
|
|
|
|
|
|
|
|
|
|
obses = {
|
|
|
|
|
"rgb": batch["observation.image"],
|
|
|
|
|
@@ -476,7 +477,7 @@ class TDMPCPolicy(nn.Module):
|
|
|
|
|
td_targets = self._td_target(next_z, reward, mask)
|
|
|
|
|
|
|
|
|
|
# Latent rollout
|
|
|
|
|
zs = torch.empty(horizon + 1, self.batch_size, self.cfg.latent_dim, device=self.device)
|
|
|
|
|
zs = torch.empty(horizon + 1, batch_size, self.cfg.latent_dim, device=self.device)
|
|
|
|
|
reward_preds = torch.empty_like(reward, device=self.device)
|
|
|
|
|
assert reward.shape[0] == horizon
|
|
|
|
|
z = self.model.encode(obs)
|
|
|
|
|
@@ -485,22 +486,21 @@ class TDMPCPolicy(nn.Module):
|
|
|
|
|
for t in range(horizon):
|
|
|
|
|
z, reward_pred = self.model.next(z, action[t])
|
|
|
|
|
zs[t + 1] = z
|
|
|
|
|
reward_preds[t] = reward_pred
|
|
|
|
|
reward_preds[t] = reward_pred.squeeze(1)
|
|
|
|
|
|
|
|
|
|
with torch.no_grad():
|
|
|
|
|
v_target = self.model_target.Q(zs[:-1].detach(), action, return_type="min")
|
|
|
|
|
|
|
|
|
|
# Predictions
|
|
|
|
|
qs = self.model.Q(zs[:-1], action, return_type="all")
|
|
|
|
|
qs = qs.squeeze(3)
|
|
|
|
|
value_info["Q"] = qs.mean().item()
|
|
|
|
|
v = self.model.V(zs[:-1])
|
|
|
|
|
value_info["V"] = v.mean().item()
|
|
|
|
|
|
|
|
|
|
# Losses
|
|
|
|
|
rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(-1, 1, 1)
|
|
|
|
|
consistency_loss = (rho * torch.mean(h.mse(zs[1:], z_targets), dim=2, keepdim=True) * loss_mask).sum(
|
|
|
|
|
dim=0
|
|
|
|
|
)
|
|
|
|
|
rho = torch.pow(self.cfg.rho, torch.arange(horizon, device=self.device)).view(-1, 1)
|
|
|
|
|
consistency_loss = (rho * torch.mean(h.mse(zs[1:], z_targets), dim=2) * loss_mask).sum(dim=0)
|
|
|
|
|
reward_loss = (rho * h.mse(reward_preds, reward) * loss_mask).sum(dim=0)
|
|
|
|
|
q_value_loss, priority_loss = 0, 0
|
|
|
|
|
for q in range(self.cfg.num_q):
|
|
|
|
|
@@ -508,7 +508,9 @@ class TDMPCPolicy(nn.Module):
|
|
|
|
|
priority_loss += (rho * h.l1(qs[q], td_targets) * loss_mask).sum(dim=0)
|
|
|
|
|
|
|
|
|
|
expectile = h.linear_schedule(self.cfg.expectile, step)
|
|
|
|
|
v_value_loss = (rho * h.l2_expectile(v_target - v, expectile=expectile) * loss_mask).sum(dim=0)
|
|
|
|
|
v_value_loss = (rho * h.l2_expectile(v_target - v, expectile=expectile).squeeze(2) * loss_mask).sum(
|
|
|
|
|
dim=0
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
total_loss = (
|
|
|
|
|
self.cfg.consistency_coef * consistency_loss
|
|
|
|
|
@@ -517,7 +519,7 @@ class TDMPCPolicy(nn.Module):
|
|
|
|
|
+ self.cfg.value_coef * v_value_loss
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
weighted_loss = (total_loss.squeeze(1) * weights).mean()
|
|
|
|
|
weighted_loss = (total_loss * weights).mean()
|
|
|
|
|
weighted_loss.register_hook(lambda grad: grad * (1 / self.cfg.horizon))
|
|
|
|
|
has_nan = torch.isnan(weighted_loss).item()
|
|
|
|
|
if has_nan:
|
|
|
|
|
|