Wandb works, One output dir
This commit is contained in:
@@ -96,7 +96,7 @@ class TDMPC(nn.Module):
|
||||
self.model_target.eval()
|
||||
self.batch_size = cfg.batch_size
|
||||
|
||||
self.step = 0
|
||||
self.register_buffer("step", torch.zeros(1))
|
||||
|
||||
def state_dict(self):
|
||||
"""Retrieve state dict of TOLD model, including slow-moving target network."""
|
||||
@@ -122,7 +122,7 @@ class TDMPC(nn.Module):
|
||||
"rgb": observation["image"],
|
||||
"state": observation["state"],
|
||||
}
|
||||
return self.act(obs, t0=t0, step=self.step)
|
||||
return self.act(obs, t0=t0, step=self.step.item())
|
||||
|
||||
@torch.no_grad()
|
||||
def act(self, obs, t0=False, step=None):
|
||||
@@ -513,5 +513,5 @@ class TDMPC(nn.Module):
|
||||
metrics.update(value_info)
|
||||
metrics.update(pi_update_info)
|
||||
|
||||
self.step = step
|
||||
self.step[0] = step
|
||||
return metrics
|
||||
|
||||
Reference in New Issue
Block a user