From c3d5e494c0b530368332c0b0eb114c32bc3b8f2c Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Mon, 28 Jul 2025 13:10:34 +0200 Subject: [PATCH] fix(policies): remove action from batch for offline evaluation (#1609) * fix(policies): remove action from batch for offline evaluation in diffusion, tdmpc, and vqbet policies * style(diffusion): correct comment capitalization for clarity in modeling_diffusion.py --- src/lerobot/policies/diffusion/modeling_diffusion.py | 6 +++++- src/lerobot/policies/tdmpc/modeling_tdmpc.py | 5 +++++ src/lerobot/policies/vqbet/modeling_vqbet.py | 7 +++++-- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/lerobot/policies/diffusion/modeling_diffusion.py b/src/lerobot/policies/diffusion/modeling_diffusion.py index 24b27396..941a3acb 100644 --- a/src/lerobot/policies/diffusion/modeling_diffusion.py +++ b/src/lerobot/policies/diffusion/modeling_diffusion.py @@ -133,11 +133,15 @@ class DiffusionPolicy(PreTrainedPolicy): "horizon" may not the best name to describe what the variable actually means, because this period is actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past. """ + # NOTE: for offline evaluation, we have action in the batch, so we need to pop it out + if ACTION in batch: + batch.pop(ACTION) + batch = self.normalize_inputs(batch) if self.config.image_features: batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) - # Note: It's important that this happens after stacking the images into a single key. + # NOTE: It's important that this happens after stacking the images into a single key. self._queues = populate_queues(self._queues, batch) if len(self._queues[ACTION]) == 0: diff --git a/src/lerobot/policies/tdmpc/modeling_tdmpc.py b/src/lerobot/policies/tdmpc/modeling_tdmpc.py index 664fe863..7ba88e5e 100644 --- a/src/lerobot/policies/tdmpc/modeling_tdmpc.py +++ b/src/lerobot/policies/tdmpc/modeling_tdmpc.py @@ -143,7 +143,12 @@ class TDMPCPolicy(PreTrainedPolicy): @torch.no_grad() def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select a single action given environment observations.""" + # NOTE: for offline evaluation, we have action in the batch, so we need to pop it out + if ACTION in batch: + batch.pop(ACTION) + batch = self.normalize_inputs(batch) + if self.config.image_features: batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))] diff --git a/src/lerobot/policies/vqbet/modeling_vqbet.py b/src/lerobot/policies/vqbet/modeling_vqbet.py index b271298a..feb65bb4 100644 --- a/src/lerobot/policies/vqbet/modeling_vqbet.py +++ b/src/lerobot/policies/vqbet/modeling_vqbet.py @@ -139,11 +139,14 @@ class VQBeTPolicy(PreTrainedPolicy): environment. It works by managing the actions in a queue and only calling `select_actions` when the queue is empty. """ - + # NOTE: for offline evaluation, we have action in the batch, so we need to pop it out + if ACTION in batch: + batch.pop(ACTION) batch = self.normalize_inputs(batch) batch = dict(batch) # shallow copy so that adding a key doesn't modify the original + # NOTE: It's important that this happens after stacking the images into a single key. batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) - # Note: It's important that this happens after stacking the images into a single key. + self._queues = populate_queues(self._queues, batch) if not self.vqbet.action_head.vqvae_model.discretized.item():