From 91b110d8063afdf7f5086aad0be8eea0ac939892 Mon Sep 17 00:00:00 2001
From: Adil Zouitine
Date: Tue, 15 Jul 2025 10:28:19 +0200
Subject: [PATCH] fix(mps): gradient exploding and nan loss issues with ACT
 (#1490)

Co-authored-by: Michel Aractingi
---
 src/lerobot/policies/act/modeling_act.py | 15 ++++++---------
 src/lerobot/scripts/train.py             |  4 ++--
 2 files changed, 8 insertions(+), 11 deletions(-)

diff --git a/src/lerobot/policies/act/modeling_act.py b/src/lerobot/policies/act/modeling_act.py
index f66c8ae8..aa81d3cd 100644
--- a/src/lerobot/policies/act/modeling_act.py
+++ b/src/lerobot/policies/act/modeling_act.py
@@ -485,12 +485,10 @@ class ACT(nn.Module):
             self.encoder_env_state_input_proj(batch["observation.environment_state"])
         )
 
-        # Camera observation features and positional embeddings.
         if self.config.image_features:
-            all_cam_features = []
-            all_cam_pos_embeds = []
 
-            # For a list of images, the H and W may vary but H*W is constant.
+            # NOTE: If modifying this section, verify on MPS devices that
+            # gradients remain stable (no explosions or NaNs).
             for img in batch["observation.images"]:
                 cam_features = self.backbone(img)["feature_map"]
                 cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
@@ -500,11 +498,10 @@
                 cam_features = einops.rearrange(cam_features, "b c h w -> (h w) b c")
                 cam_pos_embed = einops.rearrange(cam_pos_embed, "b c h w -> (h w) b c")
 
-                all_cam_features.append(cam_features)
-                all_cam_pos_embeds.append(cam_pos_embed)
-
-            encoder_in_tokens.extend(torch.cat(all_cam_features, axis=0))
-            encoder_in_pos_embed.extend(torch.cat(all_cam_pos_embeds, axis=0))
+                # Extend immediately instead of accumulating and concatenating
+                # Convert to list to extend properly
+                encoder_in_tokens.extend(list(cam_features))
+                encoder_in_pos_embed.extend(list(cam_pos_embed))
 
         # Stack all tokens along the sequence dimension.
         encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)
diff --git a/src/lerobot/scripts/train.py b/src/lerobot/scripts/train.py
index 2f2e88de..f09d231a 100644
--- a/src/lerobot/scripts/train.py
+++ b/src/lerobot/scripts/train.py
@@ -180,7 +180,7 @@ def train(cfg: TrainPipelineConfig):
         batch_size=cfg.batch_size,
         shuffle=shuffle,
         sampler=sampler,
-        pin_memory=device.type != "cpu",
+        pin_memory=device.type == "cuda",
         drop_last=False,
     )
     dl_iter = cycle(dataloader)
@@ -207,7 +207,7 @@
 
         for key in batch:
             if isinstance(batch[key], torch.Tensor):
-                batch[key] = batch[key].to(device, non_blocking=True)
+                batch[key] = batch[key].to(device, non_blocking=device.type == "cuda")
 
         train_tracker, output_dict = update_policy(
             train_tracker,