Add support for multi cam to VQ-BeT

This commit is contained in:
Remi Cadene
2024-07-16 16:23:17 +02:00
parent df23672bcd
commit b2aae48102

View File

@@ -733,12 +733,18 @@ class VQBeTRgbEncoder(nn.Module):
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
# height and width from `config.input_shapes`.
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
assert len(image_keys) == 1
image_key = image_keys[0]
dummy_input_h_w = (
config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:]
)
dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *dummy_input_h_w))
dummy_input_per_cam = {}
for image_key in image_keys:
dummy_input_h_w = (
config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:]
)
dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *dummy_input_h_w))
dummy_input_per_cam[image_key] = dummy_input
dummy_shape_per_camera = {k: dummy_input_per_cam[k].shape for k in image_keys}
if not all(dummy_shape_per_camera[k] == dummy_input.shape for k in image_keys):
raise NotImplementedError(f"At least one camera has a different shape: {dummy_shape_per_camera}")
with torch.inference_mode():
dummy_feature_map = self.backbone(dummy_input)
feature_map_shape = tuple(dummy_feature_map.shape[1:])