Compare commits

...

1 Commits

Author SHA1 Message Date
Francesco Capuano
63f7144080 fix: sharing predicted chunk with user 2025-04-23 12:59:10 +02:00

View File

@@ -142,6 +142,37 @@ class ACTPolicy(PreTrainedPolicy):
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
@torch.no_grad
def predict_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations.
This method returns the raw chunk of actions predicted by the model without
any queue management or action consumption logic.
Args:
batch: A dictionary of observation tensors.
Returns:
A tensor of shape (batch_size, chunk_size, action_dim) containing predicted actions.
"""
self.eval()
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = [batch[key] for key in self.config.image_features]
# If we are using temporal ensembling
if self.config.temporal_ensemble_coeff is not None:
actions = self.model(batch)[0] # (batch_size, chunk_size, action_dim)
actions = self.unnormalize_outputs({"action": actions})["action"]
return actions
# Standard action prediction
actions = self.model(batch)[0]
actions = self.unnormalize_outputs({"action": actions})["action"]
return actions
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)