[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
pre-commit-ci[bot]
2025-03-24 13:16:38 +00:00
committed by Michel Aractingi
parent cdcf346061
commit 1c8daf11fd
95 changed files with 1592 additions and 491 deletions


@@ -63,7 +63,9 @@ class ACTPolicy(PreTrainedPolicy):
         config.validate_features()
         self.config = config
 
-        self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
+        self.normalize_inputs = Normalize(
+            config.input_features, config.normalization_mapping, dataset_stats
+        )
         self.normalize_targets = Normalize(
             config.output_features, config.normalization_mapping, dataset_stats
         )
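For context on the calls being re-wrapped here: Normalize maps each configured input/output feature through dataset statistics before it reaches the model. A minimal sketch of the core idea, assuming mean/std statistics (the actual LeRobot Normalize also supports other scaling modes, selected per feature by normalization_mapping; normalize_mean_std is a hypothetical helper, not the library API):

import torch

def normalize_mean_std(value, mean, std, eps=1e-8):
    # Hypothetical sketch, not the library implementation.
    # Shift to zero mean and unit variance; eps guards against a zero std.
    return (value - mean) / (std + eps)

state = torch.tensor([0.5, -1.0, 2.0])
stats = {"mean": torch.tensor([0.0, 0.0, 1.0]), "std": torch.tensor([1.0, 2.0, 0.5])}
normalized = normalize_mean_std(state, stats["mean"], stats["std"])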
@@ -120,8 +122,12 @@ class ACTPolicy(PreTrainedPolicy):
         batch = self.normalize_inputs(batch)
         if self.config.image_features:
-            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.images"] = [batch[key] for key in self.config.image_features]
+            batch = dict(
+                batch
+            )  # shallow copy so that adding a key doesn't modify the original
+            batch["observation.images"] = [
+                batch[key] for key in self.config.image_features
+            ]
 
         # If we are doing temporal ensembling, do online updates where we keep track of the number of actions
         # we are ensembling over.
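The dict(batch) shallow copy that got re-wrapped above deserves a quick illustration: it lets the policy add an "observation.images" key without mutating the caller's dict, while the underlying tensors stay shared. A self-contained demonstration (the key names mirror the diff; the tensor shapes are made up):

import torch

original = {
    "observation.state": torch.zeros(2),
    "observation.images.top": torch.zeros(1, 3, 4, 4),
}
batch = dict(original)  # shallow copy: new dict, same tensor objects

batch["observation.images"] = [batch["observation.images.top"]]
assert "observation.images" not in original  # caller's dict is untouched
assert batch["observation.state"] is original["observation.state"]  # values still shared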
@@ -148,8 +154,12 @@ class ACTPolicy(PreTrainedPolicy):
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
         if self.config.image_features:
-            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.images"] = [batch[key] for key in self.config.image_features]
+            batch = dict(
+                batch
+            )  # shallow copy so that adding a key doesn't modify the original
+            batch["observation.images"] = [
+                batch[key] for key in self.config.image_features
+            ]
 
         batch = self.normalize_targets(batch)
         actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
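The (mu_hat, log_sigma_x2_hat) pair returned next to the predicted actions parameterizes the VAE posterior; ACT regularizes it toward a standard normal prior with a KL-divergence term alongside the action reconstruction loss. A sketch of that term using the diff's variable names (how it is weighted into the total loss is configuration-dependent and not shown in this hunk):

import torch

def kl_to_standard_normal(mu_hat, log_sigma_x2_hat):
    # KL(N(mu, sigma^2) || N(0, 1)), summed over the latent dim, averaged over the batch.
    return (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - log_sigma_x2_hat.exp())).sum(-1).mean()

mu_hat = torch.randn(8, 32)            # (batch, latent_dim)
log_sigma_x2_hat = torch.randn(8, 32)  # log of sigma squared
kld = kl_to_standard_normal(mu_hat, log_sigma_x2_hat)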
@@ -406,14 +416,18 @@ class ACT(nn.Module):
             n_1d_tokens += 1
         self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model)
         if self.config.image_features:
-            self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
+            self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(
+                config.dim_model // 2
+            )
 
         # Transformer decoder.
         # Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
         self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model)
 
         # Final action regression head on the output of the transformer's decoder.
-        self.action_head = nn.Linear(config.dim_model, self.config.action_feature.shape[0])
+        self.action_head = nn.Linear(
+            config.dim_model, self.config.action_feature.shape[0]
+        )
 
         self._reset_parameters()
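The decoder_pos_embed comment above points at DETR-style object queries: one learned embedding per output slot (here, one per action in the chunk) serves as the decoder's query positions, while the decoder's content stream starts at zeros. A minimal sketch of that pattern, with made-up sizes (the repeat layout assumes the (sequence, batch, dim) convention used elsewhere in this file):

import torch
from torch import nn

chunk_size, dim_model, batch_size = 100, 512, 4
decoder_pos_embed = nn.Embedding(chunk_size, dim_model)

# One learned query per action slot, repeated across the batch: (S, B, D).
queries = decoder_pos_embed.weight.unsqueeze(1).repeat(1, batch_size, 1)
decoder_in = torch.zeros_like(queries)  # content stream starts empty; queries carry position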
@@ -461,14 +475,20 @@ class ACT(nn.Module):
                 self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
             )  # (B, 1, D)
             if self.config.robot_state_feature:
-                robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
+                robot_state_embed = self.vae_encoder_robot_state_input_proj(
+                    batch["observation.state"]
+                )
                 robot_state_embed = robot_state_embed.unsqueeze(1)  # (B, 1, D)
             action_embed = self.vae_encoder_action_input_proj(
                 batch["action"]
             )  # (B, S, D)
 
             if self.config.robot_state_feature:
-                vae_encoder_input = [cls_embed, robot_state_embed, action_embed]  # (B, S+2, D)
+                vae_encoder_input = [
+                    cls_embed,
+                    robot_state_embed,
+                    action_embed,
+                ]  # (B, S+2, D)
             else:
                 vae_encoder_input = [cls_embed, action_embed]
             vae_encoder_input = torch.cat(vae_encoder_input, axis=1)
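The torch.cat at the end of this hunk stitches the [CLS] token, the optional robot-state token, and the projected action sequence into one (B, S+2, D) input for the VAE encoder; note that torch.cat accepts axis= as a NumPy-style alias for dim=. A shape-only sketch with made-up sizes:

import torch

B, S, D = 4, 100, 512
cls_embed = torch.zeros(B, 1, D)
robot_state_embed = torch.zeros(B, 1, D)
action_embed = torch.zeros(B, S, D)

# Concatenate along the sequence axis: (B, 1, D) + (B, 1, D) + (B, S, D) -> (B, S+2, D).
vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1)
assert vae_encoder_input.shape == (B, S + 2, D)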
@@ -517,7 +537,9 @@ class ACT(nn.Module):
         )
         # Robot state token.
         if self.config.robot_state_feature:
-            encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
+            encoder_in_tokens.append(
+                self.encoder_robot_state_input_proj(batch["observation.state"])
+            )
         # Environment state token.
         if self.config.env_state_feature:
             encoder_in_tokens.append(
@@ -534,7 +556,9 @@ class ACT(nn.Module):
             # For a list of images, the H and W may vary but H*W is constant.
             for img in batch["observation.images"]:
                 cam_features = self.backbone(img)["feature_map"]
-                cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
+                cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(
+                    dtype=cam_features.dtype
+                )
                 cam_features = self.encoder_img_feat_input_proj(cam_features)
 
                 # Rearrange features to (sequence, batch, dim).
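The "(sequence, batch, dim)" comment closing this hunk refers to flattening each camera's feature map so that every spatial position becomes one encoder token. A sketch of that rearrangement with einops and made-up feature-map sizes (the exact pattern string is an assumption inferred from the comment's target layout):

import torch
import einops

B, C, H, W = 4, 512, 10, 14
cam_features = torch.zeros(B, C, H, W)

# Flatten the spatial grid into a token axis: (B, C, H, W) -> (H*W, B, C).
tokens = einops.rearrange(cam_features, "b c h w -> (h w) b c")
assert tokens.shape == (H * W, B, C)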