smolfix(vla): typing and fix offline inference when action in the batch (#1597)
This commit is contained in:
@@ -384,8 +384,13 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
|||||||
return self.parameters()
|
return self.parameters()
|
||||||
|
|
||||||
def _get_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
def _get_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||||
|
# TODO: Check if this for loop is needed.
|
||||||
|
# Context: In fact, self.queues contains only ACTION field, and in inference, we don't have action in the batch
|
||||||
|
# In the case of offline inference, we have the action in the batch
|
||||||
|
# that why without the k != ACTION check, it will raise an error because we are trying to stack
|
||||||
|
# on an empty container.
|
||||||
for k in batch:
|
for k in batch:
|
||||||
if k in self._queues:
|
if k in self._queues and k != ACTION:
|
||||||
batch[k] = torch.stack(list(self._queues[k]), dim=1)
|
batch[k] = torch.stack(list(self._queues[k]), dim=1)
|
||||||
|
|
||||||
images, img_masks = self.prepare_images(batch)
|
images, img_masks = self.prepare_images(batch)
|
||||||
@@ -631,7 +636,7 @@ class VLAFlowMatching(nn.Module):
|
|||||||
└──────────────────────────────┘
|
└──────────────────────────────┘
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config: SmolVLAConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user