Wandb works, One output dir

This commit is contained in:
Cadene
2024-02-22 12:14:12 +00:00
parent ece89730e6
commit e3643d6146
11 changed files with 200 additions and 100 deletions

View File

@@ -96,7 +96,7 @@ class TDMPC(nn.Module):
self.model_target.eval()
self.batch_size = cfg.batch_size
self.step = 0
self.register_buffer("step", torch.zeros(1))
def state_dict(self):
"""Retrieve state dict of TOLD model, including slow-moving target network."""
@@ -122,7 +122,7 @@ class TDMPC(nn.Module):
"rgb": observation["image"],
"state": observation["state"],
}
return self.act(obs, t0=t0, step=self.step)
return self.act(obs, t0=t0, step=self.step.item())
@torch.no_grad()
def act(self, obs, t0=False, step=None):
@@ -513,5 +513,5 @@ class TDMPC(nn.Module):
metrics.update(value_info)
metrics.update(pi_update_info)
self.step = step
self.step[0] = step
return metrics