Assume state is always present
This commit is contained in:
@@ -616,19 +616,18 @@ class VLAFlowMatching(nn.Module):
|
|||||||
num_lang_embs = lang_emb.shape[1]
|
num_lang_embs = lang_emb.shape[1]
|
||||||
att_masks += [0] * num_lang_embs
|
att_masks += [0] * num_lang_embs
|
||||||
|
|
||||||
if state is not None:
|
state_emb = self.state_proj(state)
|
||||||
state_emb = self.state_proj(state)
|
state_emb = state_emb[:, None, :] if state_emb.ndim == 2 else state_emb
|
||||||
state_emb = state_emb[:, None, :] if state_emb.ndim == 2 else state_emb
|
embs.append(state_emb)
|
||||||
embs.append(state_emb)
|
bsize = state_emb.shape[0]
|
||||||
bsize = state_emb.shape[0]
|
device = state_emb.device
|
||||||
device = state_emb.device
|
|
||||||
|
|
||||||
states_seq_len = state_emb.shape[1]
|
states_seq_len = state_emb.shape[1]
|
||||||
state_mask = torch.ones(bsize, states_seq_len, dtype=torch.bool, device=device)
|
state_mask = torch.ones(bsize, states_seq_len, dtype=torch.bool, device=device)
|
||||||
pad_masks.append(state_mask)
|
pad_masks.append(state_mask)
|
||||||
|
|
||||||
# Set attention masks so that image and language inputs do not attend to state or actions
|
# Set attention masks so that image and language inputs do not attend to state or actions
|
||||||
att_masks += [1] * (states_seq_len)
|
att_masks += [1] * (states_seq_len)
|
||||||
embs = torch.cat(embs, dim=1)
|
embs = torch.cat(embs, dim=1)
|
||||||
pad_masks = torch.cat(pad_masks, dim=1)
|
pad_masks = torch.cat(pad_masks, dim=1)
|
||||||
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
|
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
|
||||||
|
|||||||
Reference in New Issue
Block a user