Add support for multi cam to VQ-BeT

This commit is contained in:
Remi Cadene
2024-07-16 16:23:17 +02:00
parent df23672bcd
commit b2aae48102

View File

@@ -733,12 +733,18 @@ class VQBeTRgbEncoder(nn.Module):
 # use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
 # height and width from `config.input_shapes`.
 image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
-assert len(image_keys) == 1
-image_key = image_keys[0]
-dummy_input_h_w = (
-    config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:]
-)
-dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *dummy_input_h_w))
+dummy_input_per_cam = {}
+for image_key in image_keys:
+    dummy_input_h_w = (
+        config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:]
+    )
+    dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *dummy_input_h_w))
+    dummy_input_per_cam[image_key] = dummy_input
+dummy_shape_per_camera = {k: dummy_input_per_cam[k].shape for k in image_keys}
+if not all(dummy_shape_per_camera[k] == dummy_input.shape for k in image_keys):
+    raise NotImplementedError(f"At least one camera has a different shape: {dummy_shape_per_camera}")
 with torch.inference_mode():
     dummy_feature_map = self.backbone(dummy_input)
 feature_map_shape = tuple(dummy_feature_map.shape[1:])