This commit is contained in:
Cadene
2024-02-18 01:24:19 +00:00
parent a5c305a7a4
commit fdfb2010fd

View File

@@ -128,10 +128,7 @@ class TDMPC(nn.Module):
def act(self, obs, t0=False, step=None): def act(self, obs, t0=False, step=None):
"""Take an action. Uses either MPC or the learned policy, depending on the self.cfg.mpc flag.""" """Take an action. Uses either MPC or the learned policy, depending on the self.cfg.mpc flag."""
if isinstance(obs, dict): if isinstance(obs, dict):
obs = { obs = {k: o.detach().unsqueeze(0) for k, o in obs.items()}
k: o.detach().unsqueeze(0)
for k, o in obs.items()
}
else: else:
obs = obs.detach().unsqueeze(0) obs = obs.detach().unsqueeze(0)
z = self.model.encode(obs) z = self.model.encode(obs)