backup wip
This commit is contained in:
@@ -176,7 +176,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
if self.n_action_steps is not None:
|
||||
self._action_queue = deque([], maxlen=self.n_action_steps)
|
||||
|
||||
def select_action(self, batch: dict[str, Tensor], *_, **__) -> Tensor:
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
|
||||
"""
|
||||
This method wraps `select_actions` in order to return one action at a time for execution in the
|
||||
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
||||
@@ -188,7 +189,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
self._action_queue.extend(self.select_actions(batch).transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.no_grad
|
||||
def select_actions(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Use the action chunking transformer to generate a sequence of actions."""
|
||||
self.eval()
|
||||
@@ -223,8 +224,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
{
|
||||
"observation.state": (B, 1, J) OR (B, J) tensor of robot states (joint configuration).
|
||||
"observation.images.top": (B, 1, C, H, W) OR (B, C, H, W) tensor of images.
|
||||
"action": (B, H, J) tensor of actions (positional target for robot joint configuration)
|
||||
"action_is_pad": (B, H) mask for whether the actions are padding outside of the episode bounds.
|
||||
}
|
||||
"""
|
||||
if add_obs_steps_dim:
|
||||
@@ -244,7 +243,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
# Note: no squeeze is required for "observation.images.top" because then we'd have to unsqueeze to get
|
||||
# the image index dimension.
|
||||
|
||||
def update(self, batch, *_, **__) -> dict:
|
||||
def update(self, batch, **_) -> dict:
|
||||
"""Run the model in train mode, compute the loss, and do an optimization step."""
|
||||
start_time = time.time()
|
||||
self._preprocess_batch(batch)
|
||||
|
||||
@@ -278,6 +278,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
return info
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], return_loss: bool = False) -> dict | Tensor:
|
||||
"""A forward pass through the DNN part of this policy with optional loss computation."""
|
||||
images = self.image_normalizer(batch["observation.images.top"])
|
||||
|
||||
if return_loss: # training time
|
||||
|
||||
Reference in New Issue
Block a user