fix(DiffusionPolicy): Fix bug where training without image features would crash with exception, fix environment state docs (#1617)
* Fix bug in diffusion config validation when not using image features * Fix DiffusionPolicy docstring about shape of env state
This commit is contained in:
@@ -217,12 +217,13 @@ class DiffusionConfig(PreTrainedConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Check that all input images have the same shape.
|
# Check that all input images have the same shape.
|
||||||
first_image_key, first_image_ft = next(iter(self.image_features.items()))
|
if len(self.image_features) > 0:
|
||||||
for key, image_ft in self.image_features.items():
|
first_image_key, first_image_ft = next(iter(self.image_features.items()))
|
||||||
if image_ft.shape != first_image_ft.shape:
|
for key, image_ft in self.image_features.items():
|
||||||
raise ValueError(
|
if image_ft.shape != first_image_ft.shape:
|
||||||
f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
|
raise ValueError(
|
||||||
)
|
f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def observation_delta_indices(self) -> list:
|
def observation_delta_indices(self) -> list:
|
||||||
|
|||||||
@@ -288,7 +288,7 @@ class DiffusionModel(nn.Module):
|
|||||||
|
|
||||||
"observation.images": (B, n_obs_steps, num_cameras, C, H, W)
|
"observation.images": (B, n_obs_steps, num_cameras, C, H, W)
|
||||||
AND/OR
|
AND/OR
|
||||||
"observation.environment_state": (B, environment_dim)
|
"observation.environment_state": (B, n_obs_steps, environment_dim)
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
||||||
@@ -315,7 +315,7 @@ class DiffusionModel(nn.Module):
|
|||||||
|
|
||||||
"observation.images": (B, n_obs_steps, num_cameras, C, H, W)
|
"observation.images": (B, n_obs_steps, num_cameras, C, H, W)
|
||||||
AND/OR
|
AND/OR
|
||||||
"observation.environment_state": (B, environment_dim)
|
"observation.environment_state": (B, n_obs_steps, environment_dim)
|
||||||
|
|
||||||
"action": (B, horizon, action_dim)
|
"action": (B, horizon, action_dim)
|
||||||
"action_is_pad": (B, horizon)
|
"action_is_pad": (B, horizon)
|
||||||
|
|||||||
Reference in New Issue
Block a user