smolfix(vla): typing and fix offline inference when action in the batch (#1597)

This commit is contained in:
Adil Zouitine
2025-07-28 11:44:22 +02:00
committed by GitHub
parent f089ab3628
commit 615adfc48d

View File

@@ -384,8 +384,13 @@ class SmolVLAPolicy(PreTrainedPolicy):
return self.parameters()
def _get_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
# TODO: Check if this for loop is needed.
# Context: self._queues contains only the ACTION field, and during inference the batch normally has no action.
# In the case of offline inference, however, the action is present in the batch.
# That is why, without the `k != ACTION` check, an error is raised: we would be trying to stack
# an empty container.
for k in batch:
if k in self._queues:
if k in self._queues and k != ACTION:
batch[k] = torch.stack(list(self._queues[k]), dim=1)
images, img_masks = self.prepare_images(batch)
@@ -631,7 +636,7 @@ class VLAFlowMatching(nn.Module):
└──────────────────────────────┘
"""
def __init__(self, config):
def __init__(self, config: SmolVLAConfig):
super().__init__()
self.config = config