fix(DiffusionPolicy): Fix bug where training without image features would crash with exception, fix environment state docs (#1617)

* Fix bug in diffusion config validation when not using image features

* Fix DiffusionPolicy docstring about shape of env state
This commit is contained in:
Abhay Deshpande
2025-07-29 04:40:16 -07:00
committed by GitHub
parent c14ab9e97b
commit 5695432142
2 changed files with 9 additions and 8 deletions

View File

@@ -217,12 +217,13 @@ class DiffusionConfig(PreTrainedConfig):
)
# Check that all input images have the same shape.
first_image_key, first_image_ft = next(iter(self.image_features.items()))
for key, image_ft in self.image_features.items():
if image_ft.shape != first_image_ft.shape:
raise ValueError(
f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
)
if len(self.image_features) > 0:
first_image_key, first_image_ft = next(iter(self.image_features.items()))
for key, image_ft in self.image_features.items():
if image_ft.shape != first_image_ft.shape:
raise ValueError(
f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
)
@property
def observation_delta_indices(self) -> list:

View File

@@ -288,7 +288,7 @@ class DiffusionModel(nn.Module):
"observation.images": (B, n_obs_steps, num_cameras, C, H, W)
AND/OR
"observation.environment_state": (B, environment_dim)
"observation.environment_state": (B, n_obs_steps, environment_dim)
}
"""
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
@@ -315,7 +315,7 @@ class DiffusionModel(nn.Module):
"observation.images": (B, n_obs_steps, num_cameras, C, H, W)
AND/OR
"observation.environment_state": (B, environment_dim)
"observation.environment_state": (B, n_obs_steps, environment_dim)
"action": (B, horizon, action_dim)
"action_is_pad": (B, horizon)