From b2aae48102dc5901cb36b2f8baf2981d1cf98b8e Mon Sep 17 00:00:00 2001 From: Remi Cadene Date: Tue, 16 Jul 2024 16:23:17 +0200 Subject: [PATCH] Add support for multi cam to VQ-BeT --- .../common/policies/vqbet/modeling_vqbet.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/lerobot/common/policies/vqbet/modeling_vqbet.py b/lerobot/common/policies/vqbet/modeling_vqbet.py index 6fb9c5d8..8f1583dd 100644 --- a/lerobot/common/policies/vqbet/modeling_vqbet.py +++ b/lerobot/common/policies/vqbet/modeling_vqbet.py @@ -733,12 +733,18 @@ class VQBeTRgbEncoder(nn.Module): # use the height and width from `config.crop_shape` if it is provided, otherwise it should use the # height and width from `config.input_shapes`. image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] - assert len(image_keys) == 1 - image_key = image_keys[0] - dummy_input_h_w = ( - config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:] - ) - dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *dummy_input_h_w)) + dummy_input_per_cam = {} + for image_key in image_keys: + dummy_input_h_w = ( + config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:] + ) + dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *dummy_input_h_w)) + dummy_input_per_cam[image_key] = dummy_input + + dummy_shape_per_camera = {k: dummy_input_per_cam[k].shape for k in image_keys} + if not all(dummy_shape_per_camera[k] == dummy_input.shape for k in image_keys): + raise NotImplementedError(f"At least one camera has a different shape: {dummy_shape_per_camera}") + with torch.inference_mode(): dummy_feature_map = self.backbone(dummy_input) feature_map_shape = tuple(dummy_feature_map.shape[1:])