diff --git a/src/lerobot/policies/diffusion/configuration_diffusion.py b/src/lerobot/policies/diffusion/configuration_diffusion.py index ce2de7052..54569434a 100644 --- a/src/lerobot/policies/diffusion/configuration_diffusion.py +++ b/src/lerobot/policies/diffusion/configuration_diffusion.py @@ -217,12 +217,13 @@ class DiffusionConfig(PreTrainedConfig): ) # Check that all input images have the same shape. - first_image_key, first_image_ft = next(iter(self.image_features.items())) - for key, image_ft in self.image_features.items(): - if image_ft.shape != first_image_ft.shape: - raise ValueError( - f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match." - ) + if len(self.image_features) > 0: + first_image_key, first_image_ft = next(iter(self.image_features.items())) + for key, image_ft in self.image_features.items(): + if image_ft.shape != first_image_ft.shape: + raise ValueError( + f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match." + ) @property def observation_delta_indices(self) -> list: diff --git a/src/lerobot/policies/diffusion/modeling_diffusion.py b/src/lerobot/policies/diffusion/modeling_diffusion.py index 941a3acb5..85d4d5981 100644 --- a/src/lerobot/policies/diffusion/modeling_diffusion.py +++ b/src/lerobot/policies/diffusion/modeling_diffusion.py @@ -288,7 +288,7 @@ class DiffusionModel(nn.Module): "observation.images": (B, n_obs_steps, num_cameras, C, H, W) AND/OR - "observation.environment_state": (B, environment_dim) + "observation.environment_state": (B, n_obs_steps, environment_dim) } """ batch_size, n_obs_steps = batch["observation.state"].shape[:2] @@ -315,7 +315,7 @@ class DiffusionModel(nn.Module): "observation.images": (B, n_obs_steps, num_cameras, C, H, W) AND/OR - "observation.environment_state": (B, environment_dim) + "observation.environment_state": (B, n_obs_steps, environment_dim) "action": (B, horizon, action_dim) "action_is_pad": (B, horizon)