From ad0033ae128ad1db80541d07db3578a140a1f3d0 Mon Sep 17 00:00:00 2001 From: Remi Cadene Date: Wed, 29 May 2024 09:50:44 +0000 Subject: [PATCH] fix diffusion --- lerobot/common/policies/diffusion/modeling_diffusion.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index 279a15672..cb4636c64 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -432,7 +432,10 @@ class DiffusionRgbEncoder(nn.Module): image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] # Note: we have a check in the config class to make sure all images have the same shape. image_key = image_keys[0] - dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *config.crop_shape)) + if config.crop_shape is None: + dummy_input = torch.zeros(size=(1, *config.input_shapes[image_key])) + else: + dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *config.crop_shape)) with torch.inference_mode(): dummy_feature_map = self.backbone(dummy_input) feature_map_shape = tuple(dummy_feature_map.shape[1:])