Compare commits
1 Commit
pre-commit
...
user/fraca
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
63f7144080 |
@@ -142,6 +142,37 @@ class ACTPolicy(PreTrainedPolicy):
|
|||||||
self._action_queue.extend(actions.transpose(0, 1))
|
self._action_queue.extend(actions.transpose(0, 1))
|
||||||
return self._action_queue.popleft()
|
return self._action_queue.popleft()
|
||||||
|
|
||||||
|
@torch.no_grad
|
||||||
|
def predict_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
|
"""Predict a chunk of actions given environment observations.
|
||||||
|
|
||||||
|
This method returns the raw chunk of actions predicted by the model without
|
||||||
|
any queue management or action consumption logic.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch: A dictionary of observation tensors.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tensor of shape (batch_size, chunk_size, action_dim) containing predicted actions.
|
||||||
|
"""
|
||||||
|
self.eval()
|
||||||
|
|
||||||
|
batch = self.normalize_inputs(batch)
|
||||||
|
if self.config.image_features:
|
||||||
|
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||||
|
batch["observation.images"] = [batch[key] for key in self.config.image_features]
|
||||||
|
|
||||||
|
# If we are using temporal ensembling
|
||||||
|
if self.config.temporal_ensemble_coeff is not None:
|
||||||
|
actions = self.model(batch)[0] # (batch_size, chunk_size, action_dim)
|
||||||
|
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||||
|
return actions
|
||||||
|
|
||||||
|
# Standard action prediction
|
||||||
|
actions = self.model(batch)[0]
|
||||||
|
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||||
|
return actions
|
||||||
|
|
||||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||||
"""Run the batch through the model and compute the loss for training or validation."""
|
"""Run the batch through the model and compute the loss for training or validation."""
|
||||||
batch = self.normalize_inputs(batch)
|
batch = self.normalize_inputs(batch)
|
||||||
|
|||||||
Reference in New Issue
Block a user