forked from tangger/lerobot
Add support for multi cam to VQ-BeT
This commit is contained in:
@@ -733,12 +733,18 @@ class VQBeTRgbEncoder(nn.Module):
|
|||||||
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
|
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
|
||||||
# height and width from `config.input_shapes`.
|
# height and width from `config.input_shapes`.
|
||||||
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||||
assert len(image_keys) == 1
|
dummy_input_per_cam = {}
|
||||||
image_key = image_keys[0]
|
for image_key in image_keys:
|
||||||
dummy_input_h_w = (
|
dummy_input_h_w = (
|
||||||
config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:]
|
config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:]
|
||||||
)
|
)
|
||||||
dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *dummy_input_h_w))
|
dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *dummy_input_h_w))
|
||||||
|
dummy_input_per_cam[image_key] = dummy_input
|
||||||
|
|
||||||
|
dummy_shape_per_camera = {k: dummy_input_per_cam[k].shape for k in image_keys}
|
||||||
|
if not all(dummy_shape_per_camera[k] == dummy_input.shape for k in image_keys):
|
||||||
|
raise NotImplementedError(f"At least one camera has a different shape: {dummy_shape_per_camera}")
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
dummy_feature_map = self.backbone(dummy_input)
|
dummy_feature_map = self.backbone(dummy_input)
|
||||||
feature_map_shape = tuple(dummy_feature_map.shape[1:])
|
feature_map_shape = tuple(dummy_feature_map.shape[1:])
|
||||||
|
|||||||
Reference in New Issue
Block a user