fix(DiffusionPolicy): Fix bug where training without image features would crash with exception, fix environment state docs (#1617)

* Fix bug in diffusion config validation when not using image features

* Fix DiffusionPolicy docstring about shape of env state
This commit is contained in:
Abhay Deshpande
2025-07-29 04:40:16 -07:00
committed by GitHub
parent c14ab9e97b
commit 5695432142
2 changed files with 9 additions and 8 deletions

View File

@@ -217,6 +217,7 @@ class DiffusionConfig(PreTrainedConfig):
) )
# Check that all input images have the same shape. # Check that all input images have the same shape.
if len(self.image_features) > 0:
first_image_key, first_image_ft = next(iter(self.image_features.items())) first_image_key, first_image_ft = next(iter(self.image_features.items()))
for key, image_ft in self.image_features.items(): for key, image_ft in self.image_features.items():
if image_ft.shape != first_image_ft.shape: if image_ft.shape != first_image_ft.shape:

View File

@@ -288,7 +288,7 @@ class DiffusionModel(nn.Module):
"observation.images": (B, n_obs_steps, num_cameras, C, H, W) "observation.images": (B, n_obs_steps, num_cameras, C, H, W)
AND/OR AND/OR
"observation.environment_state": (B, environment_dim) "observation.environment_state": (B, n_obs_steps, environment_dim)
} }
""" """
batch_size, n_obs_steps = batch["observation.state"].shape[:2] batch_size, n_obs_steps = batch["observation.state"].shape[:2]
@@ -315,7 +315,7 @@ class DiffusionModel(nn.Module):
"observation.images": (B, n_obs_steps, num_cameras, C, H, W) "observation.images": (B, n_obs_steps, num_cameras, C, H, W)
AND/OR AND/OR
"observation.environment_state": (B, environment_dim) "observation.environment_state": (B, n_obs_steps, environment_dim)
"action": (B, horizon, action_dim) "action": (B, horizon, action_dim)
"action_is_pad": (B, horizon) "action_is_pad": (B, horizon)