pre-commit run -a

This commit is contained in:
Remi Cadene
2024-03-02 15:58:21 +00:00
parent 1ae6205269
commit 45b4ecb727
6 changed files with 44 additions and 43 deletions

View File

@@ -5,6 +5,7 @@ import hydra
import torch
import torch.nn as nn
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusion_policy.model.common.lr_scheduler import get_scheduler
from diffusion_policy.model.vision.model_getter import get_resnet
from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder

View File

@@ -5,7 +5,6 @@ from copy import deepcopy
import einops
import numpy as np
from tensordict import TensorDict
import torch
import torch.nn as nn
@@ -127,7 +126,7 @@ class TDMPC(nn.Module):
@torch.no_grad()
def forward(self, observation, step_count):
t0 = step_count.item() == 0
# TODO(rcadene): remove unsqueeze hack...
if observation["image"].ndim == 3:
observation["image"] = observation["image"].unsqueeze(0)
@@ -147,10 +146,7 @@ class TDMPC(nn.Module):
@torch.no_grad()
def act(self, obs, t0=False, step=None):
"""Take an action. Uses either MPC or the learned policy, depending on the self.cfg.mpc flag."""
if isinstance(obs, dict):
obs = {k: o.detach() for k, o in obs.items()}
else:
obs = obs.detach()
obs = {k: o.detach() for k, o in obs.items()} if isinstance(obs, dict) else obs.detach()
z = self.model.encode(obs)
if self.cfg.mpc:
a = self.plan(z, t0=t0, step=step)