Merge remote-tracking branch 'upstream/user/rcadene/2024_03_31_remove_torchrl' into refactor_act

This commit is contained in:
Alexander Soare
2024-04-09 08:36:28 +01:00
13 changed files with 109 additions and 247 deletions

View File

@@ -65,7 +65,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
"ActionChunkingTransformerPolicy does not handle multiple observation steps."
)
def __init__(self, cfg, device, n_action_steps=1):
def __init__(self, cfg, device):
"""
TODO(alexander-soare): Add documentation for all parameters once we have model configs established.
"""
@@ -73,7 +73,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
if getattr(cfg, "n_obs_steps", 1) != 1:
raise ValueError(self._multiple_obs_steps_not_handled_msg)
self.cfg = cfg
self.n_action_steps = n_action_steps
self.n_action_steps = cfg.n_action_steps
self.device = get_safe_torch_device(device)
self.camera_names = cfg.camera_names
self.use_vae = cfg.use_vae
@@ -176,7 +176,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
if self.n_action_steps is not None:
self._action_queue = deque([], maxlen=self.n_action_steps)
def select_action(self, batch: dict[str, Tensor], *_) -> Tensor:
def select_action(self, batch: dict[str, Tensor], *_, **__) -> Tensor:
"""
This method wraps `select_actions` in order to return one action at a time for execution in the
environment. It works by managing the actions in a queue and only calling `select_actions` when the
@@ -244,7 +244,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
# Note: no squeeze is required for "observation.images.top" because then we'd have to unsqueeze to get
# the image index dimension.
def update(self, batch, *_) -> dict:
def update(self, batch, *_, **__) -> dict:
start_time = time.time()
self._preprocess_batch(batch)
@@ -311,7 +311,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
def _forward(
self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None
) -> tuple[Tensor, tuple[Tensor, Tensor]]:
) -> tuple[Tensor, tuple[Tensor | None, Tensor | None]]:
"""
Args:
robot_state: (B, J) batch of robot joint configurations.
@@ -344,16 +344,16 @@ class ActionChunkingTransformerPolicy(nn.Module):
# Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D)
# Forward pass through VAE encoder.
# Forward pass through VAE encoder to get the latent PDF parameters.
cls_token_out = self.vae_encoder(
vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2)
)[0] # select the class token, with shape (B, D)
latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
# Sample the latent with the reparameterization trick.
mu = latent_pdf_params[:, : self.latent_dim]
# This is 2*log(sigma), i.e. log(sigma^2). Done this way to match the original implementation.
log_sigma_x2 = latent_pdf_params[:, self.latent_dim :]
# Sample the latent with the reparameterization trick.
latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu)
else:
# When not using the VAE encoder, we set the latent to be all zeros.