Merge remote-tracking branch 'upstream/user/rcadene/2024_03_31_remove_torchrl' into refactor_act
This commit is contained in:
@@ -65,7 +65,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
"ActionChunkingTransformerPolicy does not handle multiple observation steps."
|
||||
)
|
||||
|
||||
def __init__(self, cfg, device, n_action_steps=1):
|
||||
def __init__(self, cfg, device):
|
||||
"""
|
||||
TODO(alexander-soare): Add documentation for all parameters once we have model configs established.
|
||||
"""
|
||||
@@ -73,7 +73,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
if getattr(cfg, "n_obs_steps", 1) != 1:
|
||||
raise ValueError(self._multiple_obs_steps_not_handled_msg)
|
||||
self.cfg = cfg
|
||||
self.n_action_steps = n_action_steps
|
||||
self.n_action_steps = cfg.n_action_steps
|
||||
self.device = get_safe_torch_device(device)
|
||||
self.camera_names = cfg.camera_names
|
||||
self.use_vae = cfg.use_vae
|
||||
@@ -176,7 +176,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
if self.n_action_steps is not None:
|
||||
self._action_queue = deque([], maxlen=self.n_action_steps)
|
||||
|
||||
def select_action(self, batch: dict[str, Tensor], *_) -> Tensor:
|
||||
def select_action(self, batch: dict[str, Tensor], *_, **__) -> Tensor:
|
||||
"""
|
||||
This method wraps `select_actions` in order to return one action at a time for execution in the
|
||||
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
||||
@@ -244,7 +244,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
# Note: no squeeze is required for "observation.images.top" because then we'd have to unsqueeze to get
|
||||
# the image index dimension.
|
||||
|
||||
def update(self, batch, *_) -> dict:
|
||||
def update(self, batch, *_, **__) -> dict:
|
||||
start_time = time.time()
|
||||
self._preprocess_batch(batch)
|
||||
|
||||
@@ -311,7 +311,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
|
||||
def _forward(
|
||||
self, robot_state: Tensor, image: Tensor, actions: Tensor | None = None
|
||||
) -> tuple[Tensor, tuple[Tensor, Tensor]]:
|
||||
) -> tuple[Tensor, tuple[Tensor | None, Tensor | None]]:
|
||||
"""
|
||||
Args:
|
||||
robot_state: (B, J) batch of robot joint configurations.
|
||||
@@ -344,16 +344,16 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
# Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
|
||||
pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D)
|
||||
|
||||
# Forward pass through VAE encoder.
|
||||
# Forward pass through VAE encoder to get the latent PDF parameters.
|
||||
cls_token_out = self.vae_encoder(
|
||||
vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2)
|
||||
)[0] # select the class token, with shape (B, D)
|
||||
latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
|
||||
|
||||
# Sample the latent with the reparameterization trick.
|
||||
mu = latent_pdf_params[:, : self.latent_dim]
|
||||
# This is 2log(sigma). Done this way to match the original implementation.
|
||||
log_sigma_x2 = latent_pdf_params[:, self.latent_dim :]
|
||||
|
||||
# Sample the latent with the reparameterization trick.
|
||||
latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu)
|
||||
else:
|
||||
# When not using the VAE encoder, we set the latent to be all zeros.
|
||||
|
||||
Reference in New Issue
Block a user