backup wip

This commit is contained in:
Alexander Soare
2024-04-11 17:51:35 +01:00
parent 91ff69d64c
commit 976a197f98
26 changed files with 661 additions and 2733 deletions

View File

@@ -176,7 +176,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
if self.n_action_steps is not None:
self._action_queue = deque([], maxlen=self.n_action_steps)
def select_action(self, batch: dict[str, Tensor], *_, **__) -> Tensor:
@torch.no_grad
def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
"""
This method wraps `select_actions` in order to return one action at a time for execution in the
environment. It works by managing the actions in a queue and only calling `select_actions` when the
@@ -188,7 +189,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
self._action_queue.extend(self.select_actions(batch).transpose(0, 1))
return self._action_queue.popleft()
@torch.no_grad()
@torch.no_grad
def select_actions(self, batch: dict[str, Tensor]) -> Tensor:
"""Use the action chunking transformer to generate a sequence of actions."""
self.eval()
@@ -223,8 +224,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
{
"observation.state": (B, 1, J) OR (B, J) tensor of robot states (joint configuration).
"observation.images.top": (B, 1, C, H, W) OR (B, C, H, W) tensor of images.
"action": (B, H, J) tensor of actions (positional target for robot joint configuration)
"action_is_pad": (B, H) mask for whether the actions are padding outside of the episode bounds.
}
"""
if add_obs_steps_dim:
@@ -244,7 +243,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
# Note: no squeeze is required for "observation.images.top" because then we'd have to unsqueeze to get
# the image index dimension.
def update(self, batch, *_, **__) -> dict:
def update(self, batch, **_) -> dict:
"""Run the model in train mode, compute the loss, and do an optimization step."""
start_time = time.time()
self._preprocess_batch(batch)
@@ -278,6 +278,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
return info
def forward(self, batch: dict[str, Tensor], return_loss: bool = False) -> dict | Tensor:
"""A forward pass through the DNN part of this policy with optional loss computation."""
images = self.image_normalizer(batch["observation.images.top"])
if return_loss: # training time