forked from tangger/lerobot
fix diffusion
This commit is contained in:
@@ -432,7 +432,10 @@ class DiffusionRgbEncoder(nn.Module):
|
||||
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
# Note: we have a check in the config class to make sure all images have the same shape.
|
||||
image_key = image_keys[0]
|
||||
dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *config.crop_shape))
|
||||
if config.crop_shape is None:
|
||||
dummy_input = torch.zeros(size=(1, *config.input_shapes[image_key]))
|
||||
else:
|
||||
dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *config.crop_shape))
|
||||
with torch.inference_mode():
|
||||
dummy_feature_map = self.backbone(dummy_input)
|
||||
feature_map_shape = tuple(dummy_feature_map.shape[1:])
|
||||
|
||||
Reference in New Issue
Block a user