From 63f714408022199b8821e5327ebc8d9c4b607d61 Mon Sep 17 00:00:00 2001 From: Francesco Capuano Date: Wed, 23 Apr 2025 12:59:10 +0200 Subject: [PATCH] fix: sharing predicted chunk with user --- lerobot/common/policies/act/modeling_act.py | 31 +++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index 72d4df03..2623e165 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -142,6 +142,37 @@ class ACTPolicy(PreTrainedPolicy): self._action_queue.extend(actions.transpose(0, 1)) return self._action_queue.popleft() + @torch.no_grad + def predict_chunk(self, batch: dict[str, Tensor]) -> Tensor: + """Predict a chunk of actions given environment observations. + + This method returns the raw chunk of actions predicted by the model without + any queue management or action consumption logic. + + Args: + batch: A dictionary of observation tensors. + + Returns: + A tensor of shape (batch_size, chunk_size, action_dim) containing predicted actions. + """ + self.eval() + + batch = self.normalize_inputs(batch) + if self.config.image_features: + batch = dict(batch) # shallow copy so that adding a key doesn't modify the original + batch["observation.images"] = [batch[key] for key in self.config.image_features] + + # If we are using temporal ensembling + if self.config.temporal_ensemble_coeff is not None: + actions = self.model(batch)[0] # (batch_size, chunk_size, action_dim) + actions = self.unnormalize_outputs({"action": actions})["action"] + return actions + + # Standard action prediction + actions = self.model(batch)[0] + actions = self.unnormalize_outputs({"action": actions})["action"] + return actions + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch)