fix diffusion
This commit is contained in:
@@ -432,7 +432,10 @@ class DiffusionRgbEncoder(nn.Module):
|
|||||||
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||||
# Note: we have a check in the config class to make sure all images have the same shape.
|
# Note: we have a check in the config class to make sure all images have the same shape.
|
||||||
image_key = image_keys[0]
|
image_key = image_keys[0]
|
||||||
dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *config.crop_shape))
|
if config.crop_shape is None:
|
||||||
|
dummy_input = torch.zeros(size=(1, *config.input_shapes[image_key]))
|
||||||
|
else:
|
||||||
|
dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *config.crop_shape))
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
dummy_feature_map = self.backbone(dummy_input)
|
dummy_feature_map = self.backbone(dummy_input)
|
||||||
feature_map_shape = tuple(dummy_feature_map.shape[1:])
|
feature_map_shape = tuple(dummy_feature_map.shape[1:])
|
||||||
|
|||||||
Reference in New Issue
Block a user