backup wip
This commit is contained in:
@@ -224,8 +224,7 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
|
||||
if is_pad is not None:
|
||||
is_pad = is_pad[:, : self.model.num_queries]
|
||||
|
||||
breakpoint()
|
||||
a_hat, is_pad_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
|
||||
a_hat, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
|
||||
|
||||
all_l1 = F.l1_loss(actions, a_hat, reduction="none")
|
||||
l1 = all_l1.mean() if is_pad is None else (all_l1 * ~is_pad.unsqueeze(-1)).mean()
|
||||
@@ -240,5 +239,5 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
|
||||
loss_dict["loss"] = loss_dict["l1"]
|
||||
return loss_dict
|
||||
else:
|
||||
action, _, (_, _) = self.model(qpos, image, env_state) # no action, sample from prior
|
||||
action, _ = self.model(qpos, image, env_state) # no action, sample from prior
|
||||
return action
|
||||
|
||||
Reference in New Issue
Block a user