fix diffusion

This commit is contained in:
Remi Cadene
2024-05-29 09:50:44 +00:00
parent b47c07fbeb
commit ad0033ae12

View File

@@ -432,7 +432,10 @@ class DiffusionRgbEncoder(nn.Module):
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
# Note: we have a check in the config class to make sure all images have the same shape.
image_key = image_keys[0]
dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *config.crop_shape))
if config.crop_shape is None:
dummy_input = torch.zeros(size=(1, *config.input_shapes[image_key]))
else:
dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *config.crop_shape))
with torch.inference_mode():
dummy_feature_map = self.backbone(dummy_input)
feature_map_shape = tuple(dummy_feature_map.shape[1:])