[HIL-SERL]Remove overstrict pre-commit modifications (#1028)

This commit is contained in:
Adil Zouitine
2025-04-24 13:48:52 +02:00
committed by GitHub
parent 671ac3411f
commit c58b504a9e
47 changed files with 163 additions and 757 deletions

View File

@@ -241,9 +241,7 @@ class ACTTemporalEnsembler:
# Note: The last dimension is unsqueeze to make sure we can broadcast properly for tensor
# operations later.
self.ensembled_actions_count = torch.ones(
(self.chunk_size, 1),
dtype=torch.long,
device=self.ensembled_actions.device,
(self.chunk_size, 1), dtype=torch.long, device=self.ensembled_actions.device
)
else:
# self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
@@ -255,10 +253,7 @@ class ACTTemporalEnsembler:
# The last action, which has no prior online average, needs to get concatenated onto the end.
self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -1:]], dim=1)
self.ensembled_actions_count = torch.cat(
[
self.ensembled_actions_count,
torch.ones_like(self.ensembled_actions_count[-1:]),
]
[self.ensembled_actions_count, torch.ones_like(self.ensembled_actions_count[-1:])]
)
# "Consume" the first action.
action, self.ensembled_actions, self.ensembled_actions_count = (
@@ -338,11 +333,7 @@ class ACT(nn.Module):
# Backbone for image feature extraction.
if self.config.image_features:
backbone_model = getattr(torchvision.models, config.vision_backbone)(
replace_stride_with_dilation=[
False,
False,
config.replace_final_stride_with_dilation,
],
replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation],
weights=config.pretrained_backbone_weights,
norm_layer=FrozenBatchNorm2d,
)
@@ -436,11 +427,7 @@ class ACT(nn.Module):
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
if self.config.robot_state_feature:
vae_encoder_input = [
cls_embed,
robot_state_embed,
action_embed,
] # (B, S+2, D)
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
else:
vae_encoder_input = [cls_embed, action_embed]
vae_encoder_input = torch.cat(vae_encoder_input, axis=1)
@@ -553,10 +540,7 @@ class ACTEncoder(nn.Module):
self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity()
def forward(
self,
x: Tensor,
pos_embed: Tensor | None = None,
key_padding_mask: Tensor | None = None,
self, x: Tensor, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None
) -> Tensor:
for layer in self.layers:
x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask)
@@ -619,10 +603,7 @@ class ACTDecoder(nn.Module):
) -> Tensor:
for layer in self.layers:
x = layer(
x,
encoder_out,
decoder_pos_embed=decoder_pos_embed,
encoder_pos_embed=encoder_pos_embed,
x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed
)
if self.norm is not None:
x = self.norm(x)