offline training + online finetuning converge to 33 reward!
This commit is contained in:
@@ -129,19 +129,17 @@ class TDMPC(nn.Module):
|
||||
"""Take an action. Uses either MPC or the learned policy, depending on the self.cfg.mpc flag."""
|
||||
if isinstance(obs, dict):
|
||||
obs = {
|
||||
k: torch.tensor(o, dtype=torch.float32, device=self.device).unsqueeze(0)
|
||||
k: o.detach().unsqueeze(0)
|
||||
for k, o in obs.items()
|
||||
}
|
||||
else:
|
||||
obs = torch.tensor(obs, dtype=torch.float32, device=self.device).unsqueeze(
|
||||
0
|
||||
)
|
||||
obs = obs.detach().unsqueeze(0)
|
||||
z = self.model.encode(obs)
|
||||
if self.cfg.mpc:
|
||||
a = self.plan(z, t0=t0, step=step)
|
||||
else:
|
||||
a = self.model.pi(z, self.cfg.min_std * self.model.training).squeeze(0)
|
||||
return a.cpu()
|
||||
return a
|
||||
|
||||
@torch.no_grad()
|
||||
def estimate_value(self, z, actions, horizon):
|
||||
@@ -324,7 +322,7 @@ class TDMPC(nn.Module):
|
||||
# trajectory t = 256, horizon h = 5
|
||||
# (t h) ... -> h t ...
|
||||
batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
|
||||
batch = batch.to("cuda")
|
||||
batch = batch.to(self.device)
|
||||
|
||||
FIRST_FRAME = 0
|
||||
obs = {
|
||||
@@ -469,7 +467,11 @@ class TDMPC(nn.Module):
|
||||
|
||||
weighted_loss = (total_loss.squeeze(1) * weights).mean()
|
||||
weighted_loss.register_hook(lambda grad: grad * (1 / self.cfg.horizon))
|
||||
weighted_loss.backward()
|
||||
has_nan = torch.isnan(weighted_loss).item()
|
||||
if has_nan:
|
||||
print(f"weighted_loss has nan: {total_loss=} {weights=}")
|
||||
else:
|
||||
weighted_loss.backward()
|
||||
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
self.model.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False
|
||||
@@ -479,12 +481,16 @@ class TDMPC(nn.Module):
|
||||
if self.cfg.per:
|
||||
# Update priorities
|
||||
priorities = priority_loss.clamp(max=1e4).detach()
|
||||
replay_buffer.update_priority(
|
||||
idxs[:num_slices],
|
||||
priorities[:num_slices],
|
||||
)
|
||||
if demo_batch_size > 0:
|
||||
demo_buffer.update_priority(demo_idxs, priorities[num_slices:])
|
||||
has_nan = torch.isnan(priorities).any().item()
|
||||
if has_nan:
|
||||
print(f"priorities has nan: {priorities=}")
|
||||
else:
|
||||
replay_buffer.update_priority(
|
||||
idxs[:num_slices],
|
||||
priorities[:num_slices],
|
||||
)
|
||||
if demo_batch_size > 0:
|
||||
demo_buffer.update_priority(demo_idxs, priorities[num_slices:])
|
||||
|
||||
# Update policy + target network
|
||||
_, pi_update_info = self.update_pi(zs[:-1].detach(), acts=action)
|
||||
|
||||
Reference in New Issue
Block a user