diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index 47378fdf..33565399 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -165,7 +165,9 @@ class DiffusionModel(nn.Module): num_images = len([k for k in config.input_shapes if k.startswith("observation.image")]) self.unet = DiffusionConditionalUnet1d( config, - global_cond_dim=(config.output_shapes["action"][0] + self.rgb_encoder.feature_dim * num_images) + global_cond_dim=( + config.input_shapes["observation.state"][0] + self.rgb_encoder.feature_dim * num_images + ) * config.n_obs_steps, )