Add support for multi cam to VQ-BeT
This commit is contained in:
@@ -733,12 +733,18 @@ class VQBeTRgbEncoder(nn.Module):
|
||||
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
|
||||
# height and width from `config.input_shapes`.
|
||||
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
assert len(image_keys) == 1
|
||||
image_key = image_keys[0]
|
||||
dummy_input_h_w = (
|
||||
config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:]
|
||||
)
|
||||
dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *dummy_input_h_w))
|
||||
dummy_input_per_cam = {}
|
||||
for image_key in image_keys:
|
||||
dummy_input_h_w = (
|
||||
config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:]
|
||||
)
|
||||
dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *dummy_input_h_w))
|
||||
dummy_input_per_cam[image_key] = dummy_input
|
||||
|
||||
dummy_shape_per_camera = {k: dummy_input_per_cam[k].shape for k in image_keys}
|
||||
if not all(dummy_shape_per_camera[k] == dummy_input.shape for k in image_keys):
|
||||
raise NotImplementedError(f"At least one camera has a different shape: {dummy_shape_per_camera}")
|
||||
|
||||
with torch.inference_mode():
|
||||
dummy_feature_map = self.backbone(dummy_input)
|
||||
feature_map_shape = tuple(dummy_feature_map.shape[1:])
|
||||
|
||||
Reference in New Issue
Block a user