revert dp changes, make act and tdmpc batch friendly

This commit is contained in:
Alexander Soare
2024-03-18 19:18:21 +00:00
parent 09ddd9bf92
commit 88347965c2
8 changed files with 32 additions and 58 deletions

View File

@@ -153,10 +153,6 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
self.eval()
# TODO(rcadene): remove unsqueeze hack to add bsize=1
observation["image", "top"] = observation["image", "top"].unsqueeze(0)
# observation["state"] = observation["state"].unsqueeze(0)
# TODO(rcadene): remove hack
# add 1 camera dimension
observation["image", "top"] = observation["image", "top"].unsqueeze(1)
@@ -180,11 +176,8 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
# exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
# raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
# remove bsize=1
action = action.squeeze(0)
# take first predicted action or n first actions
action = action[0] if self.n_action_steps == 1 else action[: self.n_action_steps]
action = action[: self.n_action_steps]
return action
def _forward(self, qpos, image, actions=None, is_pad=None):