pre-commit run -a
This commit is contained in:
@@ -5,6 +5,7 @@ import hydra
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
|
||||
from diffusion_policy.model.common.lr_scheduler import get_scheduler
|
||||
from diffusion_policy.model.vision.model_getter import get_resnet
|
||||
from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder
|
||||
|
||||
@@ -5,7 +5,6 @@ from copy import deepcopy
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
from tensordict import TensorDict
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
@@ -127,7 +126,7 @@ class TDMPC(nn.Module):
|
||||
@torch.no_grad()
|
||||
def forward(self, observation, step_count):
|
||||
t0 = step_count.item() == 0
|
||||
|
||||
|
||||
# TODO(rcadene): remove unsqueeze hack...
|
||||
if observation["image"].ndim == 3:
|
||||
observation["image"] = observation["image"].unsqueeze(0)
|
||||
@@ -147,10 +146,7 @@ class TDMPC(nn.Module):
|
||||
@torch.no_grad()
|
||||
def act(self, obs, t0=False, step=None):
|
||||
"""Take an action. Uses either MPC or the learned policy, depending on the self.cfg.mpc flag."""
|
||||
if isinstance(obs, dict):
|
||||
obs = {k: o.detach() for k, o in obs.items()}
|
||||
else:
|
||||
obs = obs.detach()
|
||||
obs = {k: o.detach() for k, o in obs.items()} if isinstance(obs, dict) else obs.detach()
|
||||
z = self.model.encode(obs)
|
||||
if self.cfg.mpc:
|
||||
a = self.plan(z, t0=t0, step=step)
|
||||
|
||||
Reference in New Issue
Block a user