Compare commits

..

1 Commits

Author SHA1 Message Date
Francesco Capuano
63f7144080 fix: sharing predicted chunk with user 2025-04-23 12:59:10 +02:00
3 changed files with 35 additions and 1 deletion

View File

@@ -142,6 +142,37 @@ class ACTPolicy(PreTrainedPolicy):
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
@torch.no_grad()
def predict_chunk(self, batch: dict[str, Tensor]) -> Tensor:
    """Predict a chunk of actions given environment observations.

    This method returns the raw chunk of actions predicted by the model without
    any queue management or action consumption logic.

    Args:
        batch: A dictionary of observation tensors.

    Returns:
        A tensor of shape (batch_size, chunk_size, action_dim) containing predicted actions.
    """
    # Side effect: switches the module to eval mode for inference.
    self.eval()

    batch = self.normalize_inputs(batch)
    if self.config.image_features:
        batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
        batch["observation.images"] = [batch[key] for key in self.config.image_features]

    # NOTE(review): the original code had separate, byte-identical branches for
    # temporal_ensemble_coeff set vs. unset — no ensembling was actually applied
    # here. The branches are collapsed; if predict_chunk is ever meant to run the
    # temporal ensembler (as select_action presumably does), that logic is still
    # missing — confirm against the queue-based path.
    actions = self.model(batch)[0]  # (batch_size, chunk_size, action_dim)
    return self.unnormalize_outputs({"action": actions})["action"]
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)

View File

@@ -49,7 +49,7 @@ dependencies = [
"datasets>=2.19.0",
"deepdiff>=7.0.1",
"diffusers>=0.27.2",
"draccus==0.10.0",
"draccus>=0.10.0",
"einops>=0.8.0",
"flask>=3.0.3",
"gdown>=5.1.0",

View File

@@ -37,6 +37,7 @@ def test_diffuser_scheduler(optimizer):
"base_lrs": [0.001],
"last_epoch": 1,
"lr_lambdas": [None],
"verbose": False,
}
assert scheduler.state_dict() == expected_state_dict
@@ -55,6 +56,7 @@ def test_vqbet_scheduler(optimizer):
"base_lrs": [0.001],
"last_epoch": 1,
"lr_lambdas": [None],
"verbose": False,
}
assert scheduler.state_dict() == expected_state_dict
@@ -75,6 +77,7 @@ def test_cosine_decay_with_warmup_scheduler(optimizer):
"base_lrs": [0.001],
"last_epoch": 1,
"lr_lambdas": [None],
"verbose": False,
}
assert scheduler.state_dict() == expected_state_dict