[Fix] Device Error on SmolVLA Multi-GPU Training (#2270)

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
Hakjin Lee
2025-10-21 21:26:31 +09:00
committed by GitHub
parent abe9e79825
commit 63cd2111ad
2 changed files with 3 additions and 1 deletion

View File

@@ -485,6 +485,7 @@ class VLAFlowMatching(nn.Module):
num_vlm_layers=self.config.num_vlm_layers,
self_attn_every_n_layers=self.config.self_attn_every_n_layers,
expert_width_multiplier=self.config.expert_width_multiplier,
device=self.config.device,
)
self.state_proj = nn.Linear(
self.config.max_state_dim, self.vlm_with_expert.config.text_config.hidden_size

View File

@@ -70,13 +70,14 @@ class SmolVLMWithExpertModel(nn.Module):
num_vlm_layers: int = -1,
self_attn_every_n_layers: int = -1,
expert_width_multiplier: float = 0.5,
device: str = "auto",
):
super().__init__()
if load_vlm_weights:
print(f"Loading {model_id} weights ...")
self.vlm = AutoModelForImageTextToText.from_pretrained(
model_id,
device_map="auto",
device_map=device,
torch_dtype="bfloat16",
low_cpu_mem_usage=True,
)