assume always there is state
This commit is contained in:
@@ -616,7 +616,6 @@ class VLAFlowMatching(nn.Module):
|
|||||||
num_lang_embs = lang_emb.shape[1]
|
num_lang_embs = lang_emb.shape[1]
|
||||||
att_masks += [0] * num_lang_embs
|
att_masks += [0] * num_lang_embs
|
||||||
|
|
||||||
if state is not None:
|
|
||||||
state_emb = self.state_proj(state)
|
state_emb = self.state_proj(state)
|
||||||
state_emb = state_emb[:, None, :] if state_emb.ndim == 2 else state_emb
|
state_emb = state_emb[:, None, :] if state_emb.ndim == 2 else state_emb
|
||||||
embs.append(state_emb)
|
embs.append(state_emb)
|
||||||
|
|||||||
Reference in New Issue
Block a user