fix(mps): gradient exploding and nan loss issues with ACT (#1490)
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
@@ -485,12 +485,10 @@ class ACT(nn.Module):
                 self.encoder_env_state_input_proj(batch["observation.environment_state"])
             )
 
         # Camera observation features and positional embeddings.
         if self.config.image_features:
-            all_cam_features = []
-            all_cam_pos_embeds = []
 
             # For a list of images, the H and W may vary but H*W is constant.
+            # NOTE: If modifying this section, verify on MPS devices that
+            # gradients remain stable (no explosions or NaNs).
             for img in batch["observation.images"]:
                 cam_features = self.backbone(img)["feature_map"]
                 cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
@@ -500,11 +498,10 @@ class ACT(nn.Module):
                 cam_features = einops.rearrange(cam_features, "b c h w -> (h w) b c")
                 cam_pos_embed = einops.rearrange(cam_pos_embed, "b c h w -> (h w) b c")
 
-                all_cam_features.append(cam_features)
-                all_cam_pos_embeds.append(cam_pos_embed)
-
-            encoder_in_tokens.extend(torch.cat(all_cam_features, axis=0))
-            encoder_in_pos_embed.extend(torch.cat(all_cam_pos_embeds, axis=0))
+                # Extend immediately instead of accumulating and concatenating
+                # Convert to list to extend properly
+                encoder_in_tokens.extend(list(cam_features))
+                encoder_in_pos_embed.extend(list(cam_pos_embed))
 
         # Stack all tokens along the sequence dimension.
         encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)
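
The change above swaps the accumulate-then-concatenate pattern for extending the encoder token lists one camera at a time, which is what resolved the exploding gradients and NaN losses seen on MPS. A minimal standalone sketch of the two patterns, assuming illustrative tensor shapes (the encoder_in_tokens and all_cam_features names come from the diff; the shapes and data are made up for the example):

import torch

# Illustrative per-camera token tensors: (h*w, batch, dim), as after einops.rearrange.
cam_token_maps = [torch.randn(16, 2, 8) for _ in range(3)]

# Old pattern: accumulate per-camera tensors, concatenate once at the end.
all_cam_features = []
for cam_features in cam_token_maps:
    all_cam_features.append(cam_features)
encoder_in_tokens = list(torch.cat(all_cam_features, dim=0))  # 48 tokens of shape (2, 8)

# New pattern: extend the token list immediately inside the loop.
encoder_in_tokens = []
for cam_features in cam_token_maps:
    encoder_in_tokens.extend(list(cam_features))  # iterating a tensor yields slices along dim 0

encoder_in_tokens = torch.stack(encoder_in_tokens, dim=0)  # (48, 2, 8), same result either way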
@@ -180,7 +180,7 @@ def train(cfg: TrainPipelineConfig):
         batch_size=cfg.batch_size,
         shuffle=shuffle,
         sampler=sampler,
-        pin_memory=device.type != "cpu",
+        pin_memory=device.type == "cuda",
         drop_last=False,
     )
     dl_iter = cycle(dataloader)
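
Pinned (page-locked) host memory only speeds up host-to-device copies on CUDA; with the previous device.type != "cpu" condition it was also enabled on MPS, where it brings no benefit. A minimal sketch of the gated DataLoader setup, using a made-up stand-in dataset rather than the training script's actual arguments:

import torch
from torch.utils.data import DataLoader, TensorDataset

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

dataset = TensorDataset(torch.randn(32, 4))  # illustrative stand-in for the real dataset
dataloader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    pin_memory=device.type == "cuda",  # page-locked buffers only help CUDA transfers
    drop_last=False,
)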
@@ -207,7 +207,7 @@ def train(cfg: TrainPipelineConfig):
 
         for key in batch:
             if isinstance(batch[key], torch.Tensor):
-                batch[key] = batch[key].to(device, non_blocking=True)
+                batch[key] = batch[key].to(device, non_blocking=device.type == "cuda")
 
         train_tracker, output_dict = update_policy(
             train_tracker,
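
In the same spirit, non_blocking=True only yields an asynchronous copy for CUDA targets fed from pinned host memory; gating it on device.type == "cuda" makes the transfer an ordinary synchronous copy on MPS and CPU. A minimal sketch of the same gating on a batch dictionary (keys and shapes are illustrative, not taken from the dataset):

import torch

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

batch = {"observation.state": torch.randn(8, 14), "action": torch.randn(8, 14)}
for key in batch:
    if isinstance(batch[key], torch.Tensor):
        # Asynchronous copies only apply to CUDA with pinned host memory.
        batch[key] = batch[key].to(device, non_blocking=device.type == "cuda")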