smolfix(vla): typing and fix offline inference when action in the batch (#1597)
This commit is contained in:
@@ -384,8 +384,13 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
return self.parameters()
|
||||
|
||||
def _get_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
# TODO: Check if this for loop is needed.
|
||||
# Context: In fact, self.queues contains only ACTION field, and in inference, we don't have action in the batch
|
||||
# In the case of offline inference, we have the action in the batch
|
||||
# that why without the k != ACTION check, it will raise an error because we are trying to stack
|
||||
# on an empty container.
|
||||
for k in batch:
|
||||
if k in self._queues:
|
||||
if k in self._queues and k != ACTION:
|
||||
batch[k] = torch.stack(list(self._queues[k]), dim=1)
|
||||
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
@@ -631,7 +636,7 @@ class VLAFlowMatching(nn.Module):
|
||||
└──────────────────────────────┘
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(self, config: SmolVLAConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
|
||||
Reference in New Issue
Block a user