backup wip

This commit is contained in:
Alexander Soare
2024-04-03 14:21:07 +01:00
parent c7d70a8db9
commit 110ac5ffa1
6 changed files with 182 additions and 191 deletions

View File

@@ -224,8 +224,7 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
if is_pad is not None:
is_pad = is_pad[:, : self.model.num_queries]
breakpoint()
a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
a_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
all_l1 = F.l1_loss(actions, a_hat, reduction="none")
l1 = all_l1.mean() if is_pad is None else (all_l1 * ~is_pad.unsqueeze(-1)).mean()
@@ -240,5 +239,5 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
loss_dict["loss"] = loss_dict["l1"]
return loss_dict
else:
action, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior
action, _ = self.model(qpos, image, env_state) # no action, sample from prior
return action