diff --git a/lerobot/common/policies/smolvla/modeling_smolvla.py b/lerobot/common/policies/smolvla/modeling_smolvla.py
index 57506b7a..ac56205a 100644
--- a/lerobot/common/policies/smolvla/modeling_smolvla.py
+++ b/lerobot/common/policies/smolvla/modeling_smolvla.py
@@ -616,19 +616,18 @@ class VLAFlowMatching(nn.Module):
         num_lang_embs = lang_emb.shape[1]
         att_masks += [0] * num_lang_embs

-        if state is not None:
-            state_emb = self.state_proj(state)
-            state_emb = state_emb[:, None, :] if state_emb.ndim == 2 else state_emb
-            embs.append(state_emb)
-            bsize = state_emb.shape[0]
-            device = state_emb.device
+        state_emb = self.state_proj(state)
+        state_emb = state_emb[:, None, :] if state_emb.ndim == 2 else state_emb
+        embs.append(state_emb)
+        bsize = state_emb.shape[0]
+        device = state_emb.device

-            states_seq_len = state_emb.shape[1]
-            state_mask = torch.ones(bsize, states_seq_len, dtype=torch.bool, device=device)
-            pad_masks.append(state_mask)
+        states_seq_len = state_emb.shape[1]
+        state_mask = torch.ones(bsize, states_seq_len, dtype=torch.bool, device=device)
+        pad_masks.append(state_mask)

-            # Set attention masks so that image and language inputs do not attend to state or actions
-            att_masks += [1] * (states_seq_len)
+        # Set attention masks so that image and language inputs do not attend to state or actions
+        att_masks += [1] * (states_seq_len)

         embs = torch.cat(embs, dim=1)
         pad_masks = torch.cat(pad_masks, dim=1)
         att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)