From 615adfc48d60a8ecb9e1891c773405268770e414 Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Mon, 28 Jul 2025 11:44:22 +0200 Subject: [PATCH] smolfix(vla): typing and fix offline inference when action in the batch (#1597) --- src/lerobot/policies/smolvla/modeling_smolvla.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index a31e1b078..d2f78068c 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -384,8 +384,13 @@ class SmolVLAPolicy(PreTrainedPolicy): return self.parameters() def _get_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: + # TODO: Check if this for loop is needed. + # Context: self._queues contains only the ACTION field, and in inference, we don't have the action in the batch. + # In the case of offline inference, we do have the action in the batch; + # that is why, without the k != ACTION check, it will raise an error because we are trying to stack + # on an empty container. for k in batch: - if k in self._queues: + if k in self._queues and k != ACTION: batch[k] = torch.stack(list(self._queues[k]), dim=1) images, img_masks = self.prepare_images(batch) @@ -631,7 +636,7 @@ class VLAFlowMatching(nn.Module): └──────────────────────────────┘ """ - def __init__(self, config): + def __init__(self, config: SmolVLAConfig): super().__init__() self.config = config