fix(mps): exploding gradients and NaN loss issues with ACT (#1490)

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
This commit is contained in:
Adil Zouitine
2025-07-15 10:28:19 +02:00
committed by GitHub
parent 519b76110e
commit 91b110d806
2 changed files with 8 additions and 11 deletions

View File

@@ -485,12 +485,10 @@ class ACT(nn.Module):
self.encoder_env_state_input_proj(batch["observation.environment_state"])
)
# Camera observation features and positional embeddings.
if self.config.image_features:
all_cam_features = []
all_cam_pos_embeds = []
# For a list of images, the H and W may vary but H*W is constant.
# NOTE: If modifying this section, verify on MPS devices that
# gradients remain stable (no explosions or NaNs).
for img in batch["observation.images"]:
cam_features = self.backbone(img)["feature_map"]
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
@@ -500,11 +498,10 @@ class ACT(nn.Module):
cam_features = einops.rearrange(cam_features, "b c h w -> (h w) b c")
cam_pos_embed = einops.rearrange(cam_pos_embed, "b c h w -> (h w) b c")
all_cam_features.append(cam_features)
all_cam_pos_embeds.append(cam_pos_embed)
encoder_in_tokens.extend(torch.cat(all_cam_features, axis=0))
encoder_in_pos_embed.extend(torch.cat(all_cam_pos_embeds, axis=0))
# Extend immediately instead of accumulating and concatenating
# Convert to list to extend properly
encoder_in_tokens.extend(list(cam_features))
encoder_in_pos_embed.extend(list(cam_pos_embed))
# Stack all tokens along the sequence dimension.
encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)

View File

@@ -180,7 +180,7 @@ def train(cfg: TrainPipelineConfig):
batch_size=cfg.batch_size,
shuffle=shuffle,
sampler=sampler,
pin_memory=device.type != "cpu",
pin_memory=device.type == "cuda",
drop_last=False,
)
dl_iter = cycle(dataloader)
@@ -207,7 +207,7 @@ def train(cfg: TrainPipelineConfig):
for key in batch:
if isinstance(batch[key], torch.Tensor):
batch[key] = batch[key].to(device, non_blocking=True)
batch[key] = batch[key].to(device, non_blocking=device.type == "cuda")
train_tracker, output_dict = update_policy(
train_tracker,