backup wip

This commit is contained in:
Alexander Soare
2024-03-20 15:01:27 +00:00
parent 32e3f71dd1
commit d323993569
7 changed files with 71 additions and 81 deletions

View File

@@ -15,11 +15,12 @@ from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules
class RgbEncoder(nn.Module):
"""Following `VisualCore` from Robomimic 0.2.0."""
def __init__(self, input_shape, model_name="resnet18", pretrained=False, num_keypoints=32):
def __init__(self, input_shape, model_name="resnet18", pretrained=False, relu=True, num_keypoints=32):
"""
input_shape: channel-first input shape (C, H, W)
model_name: a timm model name.
pretrained: whether to use timm pretrained weights.
relu: whether to use relu as a final step.
num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image).
"""
super().__init__()
@@ -30,9 +31,11 @@ class RgbEncoder(nn.Module):
feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:])
self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints)
self.out = nn.Linear(num_keypoints * 2, num_keypoints * 2)
self.relu = nn.ReLU() if relu else nn.Identity()
def forward(self, x):
return self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1))
# TODO(now): make nonlinearity optional
return self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)))
class MultiImageObsEncoder(ModuleAttrMixin):
@@ -182,7 +185,6 @@ class MultiImageObsEncoder(ModuleAttrMixin):
feature = torch.moveaxis(feature, 0, 1)
# (B,N*D)
feature = feature.reshape(batch_size, -1)
# feature = torch.nn.functional.relu(feature) # TODO: make optional
features.append(feature)
else:
# run each rgb obs to independent models
@@ -195,7 +197,6 @@ class MultiImageObsEncoder(ModuleAttrMixin):
assert img.shape[1:] == self.key_shape_map[key]
img = self.key_transform_map[key](img)
feature = self.key_model_map[key](img)
# feature = torch.nn.functional.relu(feature) # TODO: make optional
features.append(feature)
# concatenate all features