revert dp changes, make act and tdmpc batch friendly
This commit is contained in:
@@ -128,11 +128,6 @@ class TDMPC(AbstractPolicy):
|
||||
def select_action(self, observation, step_count):
|
||||
t0 = step_count.item() == 0
|
||||
|
||||
# TODO(rcadene): remove unsqueeze hack...
|
||||
if observation["image"].ndim == 3:
|
||||
observation["image"] = observation["image"].unsqueeze(0)
|
||||
observation["state"] = observation["state"].unsqueeze(0)
|
||||
|
||||
obs = {
|
||||
# TODO(rcadene): remove contiguous hack...
|
||||
"rgb": observation["image"].contiguous(),
|
||||
@@ -149,7 +144,7 @@ class TDMPC(AbstractPolicy):
|
||||
if self.cfg.mpc:
|
||||
a = self.plan(z, t0=t0, step=step)
|
||||
else:
|
||||
a = self.model.pi(z, self.cfg.min_std * self.model.training).squeeze(0)
|
||||
a = self.model.pi(z, self.cfg.min_std * self.model.training)
|
||||
return a
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
Reference in New Issue
Block a user