diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py
index 279a15672..cb4636c64 100644
--- a/lerobot/common/policies/diffusion/modeling_diffusion.py
+++ b/lerobot/common/policies/diffusion/modeling_diffusion.py
@@ -432,7 +432,10 @@ class DiffusionRgbEncoder(nn.Module):
         image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
         # Note: we have a check in the config class to make sure all images have the same shape.
         image_key = image_keys[0]
-        dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *config.crop_shape))
+        if config.crop_shape is None:
+            dummy_input = torch.zeros(size=(1, *config.input_shapes[image_key]))
+        else:
+            dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *config.crop_shape))
         with torch.inference_mode():
             dummy_feature_map = self.backbone(dummy_input)
         feature_map_shape = tuple(dummy_feature_map.shape[1:])
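
For context, a minimal self-contained sketch of the pattern this patch fixes: running a dummy forward pass to infer the backbone's output feature-map shape, with the spatial dimensions depending on whether cropping is enabled. This is not the lerobot code itself; it assumes a torchvision ResNet-18 backbone, and `input_shape` / `crop_shape` are hypothetical stand-ins for `config.input_shapes[image_key]` and `config.crop_shape`.

```python
# Sketch (assumptions noted above): infer a backbone's feature-map shape
# with a dummy forward pass, handling an optional crop.
import torch
import torchvision

# Strip the avgpool and fc layers so the backbone outputs a feature map.
backbone = torch.nn.Sequential(
    *list(torchvision.models.resnet18(weights=None).children())[:-2]
)

input_shape = (3, 96, 96)  # (C, H, W) of the raw observation image (hypothetical)
crop_shape = None          # or e.g. (84, 84) when cropping is enabled

if crop_shape is None:
    # No crop: the backbone sees the full image resolution.
    dummy_input = torch.zeros(size=(1, *input_shape))
else:
    # Crop: keep the channel count, substitute the crop's spatial dims.
    dummy_input = torch.zeros(size=(1, input_shape[0], *crop_shape))

with torch.inference_mode():
    dummy_feature_map = backbone(dummy_input)

# Drop the batch dim; for a 96x96 input this is (512, 3, 3).
feature_map_shape = tuple(dummy_feature_map.shape[1:])
print(feature_map_shape)
```

Before the patch, the unconditional `*config.crop_shape` unpacking would raise a `TypeError` whenever `crop_shape` was `None`; the added branch falls back to the full input shape in that case.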