pre-commit run -a

This commit is contained in:
Remi Cadene
2024-03-02 15:58:21 +00:00
parent 1ae6205269
commit 45b4ecb727
6 changed files with 44 additions and 43 deletions

View File

@@ -5,7 +5,6 @@ from copy import deepcopy
import einops
import numpy as np
from tensordict import TensorDict
import torch
import torch.nn as nn
@@ -127,7 +126,7 @@ class TDMPC(nn.Module):
@torch.no_grad()
def forward(self, observation, step_count):
t0 = step_count.item() == 0
# TODO(rcadene): remove unsqueeze hack...
if observation["image"].ndim == 3:
observation["image"] = observation["image"].unsqueeze(0)
@@ -147,10 +146,7 @@ class TDMPC(nn.Module):
@torch.no_grad()
def act(self, obs, t0=False, step=None):
"""Take an action. Uses either MPC or the learned policy, depending on the self.cfg.mpc flag."""
if isinstance(obs, dict):
obs = {k: o.detach() for k, o in obs.items()}
else:
obs = obs.detach()
obs = {k: o.detach() for k, o in obs.items()} if isinstance(obs, dict) else obs.detach()
z = self.model.encode(obs)
if self.cfg.mpc:
a = self.plan(z, t0=t0, step=step)