diff --git a/lerobot/common/tdmpc.py b/lerobot/common/tdmpc.py index 03aeb7d70..de4543dba 100644 --- a/lerobot/common/tdmpc.py +++ b/lerobot/common/tdmpc.py @@ -128,10 +128,7 @@ class TDMPC(nn.Module): def act(self, obs, t0=False, step=None): """Take an action. Uses either MPC or the learned policy, depending on the self.cfg.mpc flag.""" if isinstance(obs, dict): - obs = { - k: o.detach().unsqueeze(0) - for k, o in obs.items() - } + obs = {k: o.detach().unsqueeze(0) for k, o in obs.items()} else: obs = obs.detach().unsqueeze(0) z = self.model.encode(obs)