offline training + online finetuning converge to 33 reward!

This commit is contained in:
Cadene
2024-02-18 01:23:44 +00:00
parent 0b4084f0f8
commit a5c305a7a4
3 changed files with 42 additions and 36 deletions

View File

@@ -129,19 +129,17 @@ class TDMPC(nn.Module):
"""Take an action. Uses either MPC or the learned policy, depending on the self.cfg.mpc flag."""
if isinstance(obs, dict):
obs = {
k: torch.tensor(o, dtype=torch.float32, device=self.device).unsqueeze(0)
k: o.detach().unsqueeze(0)
for k, o in obs.items()
}
else:
obs = torch.tensor(obs, dtype=torch.float32, device=self.device).unsqueeze(
0
)
obs = obs.detach().unsqueeze(0)
z = self.model.encode(obs)
if self.cfg.mpc:
a = self.plan(z, t0=t0, step=step)
else:
a = self.model.pi(z, self.cfg.min_std * self.model.training).squeeze(0)
return a.cpu()
return a
@torch.no_grad()
def estimate_value(self, z, actions, horizon):
@@ -324,7 +322,7 @@ class TDMPC(nn.Module):
# trajectory t = 256, horizon h = 5
# (t h) ... -> h t ...
batch = batch.reshape(num_slices, horizon).transpose(1, 0).contiguous()
batch = batch.to("cuda")
batch = batch.to(self.device)
FIRST_FRAME = 0
obs = {
@@ -469,7 +467,11 @@ class TDMPC(nn.Module):
weighted_loss = (total_loss.squeeze(1) * weights).mean()
weighted_loss.register_hook(lambda grad: grad * (1 / self.cfg.horizon))
weighted_loss.backward()
has_nan = torch.isnan(weighted_loss).item()
if has_nan:
print(f"weighted_loss has nan: {total_loss=} {weights=}")
else:
weighted_loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
self.model.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False
@@ -479,12 +481,16 @@ class TDMPC(nn.Module):
if self.cfg.per:
# Update priorities
priorities = priority_loss.clamp(max=1e4).detach()
replay_buffer.update_priority(
idxs[:num_slices],
priorities[:num_slices],
)
if demo_batch_size > 0:
demo_buffer.update_priority(demo_idxs, priorities[num_slices:])
has_nan = torch.isnan(priorities).any().item()
if has_nan:
print(f"priorities has nan: {priorities=}")
else:
replay_buffer.update_priority(
idxs[:num_slices],
priorities[:num_slices],
)
if demo_batch_size > 0:
demo_buffer.update_priority(demo_idxs, priorities[num_slices:])
# Update policy + target network
_, pi_update_info = self.update_pi(zs[:-1].detach(), acts=action)