Compare commits
1 Commits
main
...
user/fraca
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
63f7144080 |
@@ -142,6 +142,37 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
@torch.no_grad
|
||||
def predict_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Predict a chunk of actions given environment observations.
|
||||
|
||||
This method returns the raw chunk of actions predicted by the model without
|
||||
any queue management or action consumption logic.
|
||||
|
||||
Args:
|
||||
batch: A dictionary of observation tensors.
|
||||
|
||||
Returns:
|
||||
A tensor of shape (batch_size, chunk_size, action_dim) containing predicted actions.
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = [batch[key] for key in self.config.image_features]
|
||||
|
||||
# If we are using temporal ensembling
|
||||
if self.config.temporal_ensemble_coeff is not None:
|
||||
actions = self.model(batch)[0] # (batch_size, chunk_size, action_dim)
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
return actions
|
||||
|
||||
# Standard action prediction
|
||||
actions = self.model(batch)[0]
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
return actions
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
Reference in New Issue
Block a user