revert dp changes, make act and tdmpc batch friendly

This commit is contained in:
Alexander Soare
2024-03-18 19:18:21 +00:00
parent 09ddd9bf92
commit 88347965c2
8 changed files with 32 additions and 58 deletions

View File

@@ -128,11 +128,6 @@ class TDMPC(AbstractPolicy):
def select_action(self, observation, step_count):
t0 = step_count.item() == 0
# TODO(rcadene): remove unsqueeze hack...
if observation["image"].ndim == 3:
observation["image"] = observation["image"].unsqueeze(0)
observation["state"] = observation["state"].unsqueeze(0)
obs = {
# TODO(rcadene): remove contiguous hack...
"rgb": observation["image"].contiguous(),
@@ -149,7 +144,7 @@ class TDMPC(AbstractPolicy):
if self.cfg.mpc:
a = self.plan(z, t0=t0, step=step)
else:
a = self.model.pi(z, self.cfg.min_std * self.model.training).squeeze(0)
a = self.model.pi(z, self.cfg.min_std * self.model.training)
return a
@torch.no_grad()