From 5695432142c44f787ab6432f44faa8126932bda5 Mon Sep 17 00:00:00 2001 From: Abhay Deshpande Date: Tue, 29 Jul 2025 04:40:16 -0700 Subject: [PATCH] fix(DiffusionPolicy): Fix bug where training without image features would crash with exception, fix environment state docs (#1617) * Fix bug in diffusion config validation when not using image features * Fix DiffusionPolicy docstring about shape of env state --- .../policies/diffusion/configuration_diffusion.py | 13 +++++++------ .../policies/diffusion/modeling_diffusion.py | 4 ++-- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/lerobot/policies/diffusion/configuration_diffusion.py b/src/lerobot/policies/diffusion/configuration_diffusion.py index ce2de7052..54569434a 100644 --- a/src/lerobot/policies/diffusion/configuration_diffusion.py +++ b/src/lerobot/policies/diffusion/configuration_diffusion.py @@ -217,12 +217,13 @@ class DiffusionConfig(PreTrainedConfig): ) # Check that all input images have the same shape. - first_image_key, first_image_ft = next(iter(self.image_features.items())) - for key, image_ft in self.image_features.items(): - if image_ft.shape != first_image_ft.shape: - raise ValueError( - f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match." - ) + if len(self.image_features) > 0: + first_image_key, first_image_ft = next(iter(self.image_features.items())) + for key, image_ft in self.image_features.items(): + if image_ft.shape != first_image_ft.shape: + raise ValueError( + f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match." + ) @property def observation_delta_indices(self) -> list: diff --git a/src/lerobot/policies/diffusion/modeling_diffusion.py b/src/lerobot/policies/diffusion/modeling_diffusion.py index 941a3acb5..85d4d5981 100644 --- a/src/lerobot/policies/diffusion/modeling_diffusion.py +++ b/src/lerobot/policies/diffusion/modeling_diffusion.py @@ -288,7 +288,7 @@ class DiffusionModel(nn.Module): "observation.images": (B, n_obs_steps, num_cameras, C, H, W) AND/OR - "observation.environment_state": (B, environment_dim) + "observation.environment_state": (B, n_obs_steps, environment_dim) } """ batch_size, n_obs_steps = batch["observation.state"].shape[:2] @@ -315,7 +315,7 @@ class DiffusionModel(nn.Module): "observation.images": (B, n_obs_steps, num_cameras, C, H, W) AND/OR - "observation.environment_state": (B, environment_dim) + "observation.environment_state": (B, n_obs_steps, environment_dim) "action": (B, horizon, action_dim) "action_is_pad": (B, horizon)