fix(policies): remove action from batch for offline evaluation (#1609)

* fix(policies): remove action from batch for offline evaluation in diffusion, tdmpc, and vqbet policies

* style(diffusion): correct comment capitalization for clarity in modeling_diffusion.py
This commit is contained in:
Adil Zouitine
2025-07-28 13:10:34 +02:00
committed by GitHub
parent 664e069c3f
commit c3d5e494c0
3 changed files with 15 additions and 3 deletions

View File

@@ -133,11 +133,15 @@ class DiffusionPolicy(PreTrainedPolicy):
"horizon" may not the best name to describe what the variable actually means, because this period is "horizon" may not the best name to describe what the variable actually means, because this period is
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past. actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
""" """
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
if ACTION in batch:
batch.pop(ACTION)
batch = self.normalize_inputs(batch) batch = self.normalize_inputs(batch)
if self.config.image_features: if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
# Note: It's important that this happens after stacking the images into a single key. # NOTE: It's important that this happens after stacking the images into a single key.
self._queues = populate_queues(self._queues, batch) self._queues = populate_queues(self._queues, batch)
if len(self._queues[ACTION]) == 0: if len(self._queues[ACTION]) == 0:

View File

@@ -143,7 +143,12 @@ class TDMPCPolicy(PreTrainedPolicy):
@torch.no_grad() @torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor: def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.""" """Select a single action given environment observations."""
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
if ACTION in batch:
batch.pop(ACTION)
batch = self.normalize_inputs(batch) batch = self.normalize_inputs(batch)
if self.config.image_features: if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))] batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))]

View File

@@ -139,11 +139,14 @@ class VQBeTPolicy(PreTrainedPolicy):
environment. It works by managing the actions in a queue and only calling `select_actions` when the environment. It works by managing the actions in a queue and only calling `select_actions` when the
queue is empty. queue is empty.
""" """
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
if ACTION in batch:
batch.pop(ACTION)
batch = self.normalize_inputs(batch) batch = self.normalize_inputs(batch)
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
# NOTE: It's important that this happens after stacking the images into a single key.
batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
# Note: It's important that this happens after stacking the images into a single key.
self._queues = populate_queues(self._queues, batch) self._queues = populate_queues(self._queues, batch)
if not self.vqbet.action_head.vqvae_model.discretized.item(): if not self.vqbet.action_head.vqvae_model.discretized.item():