backup wip
This commit is contained in:
@@ -15,11 +15,12 @@ from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules
|
||||
class RgbEncoder(nn.Module):
|
||||
"""Following `VisualCore` from Robomimic 0.2.0."""
|
||||
|
||||
def __init__(self, input_shape, model_name="resnet18", pretrained=False, num_keypoints=32):
|
||||
def __init__(self, input_shape, model_name="resnet18", pretrained=False, relu=True, num_keypoints=32):
|
||||
"""
|
||||
input_shape: channel-first input shape (C, H, W)
|
||||
resnet_name: a timm model name.
|
||||
pretrained: whether to use timm pretrained weights.
|
||||
rele: whether to use relu as a final step.
|
||||
num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image).
|
||||
"""
|
||||
super().__init__()
|
||||
@@ -30,9 +31,11 @@ class RgbEncoder(nn.Module):
|
||||
feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:])
|
||||
self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints)
|
||||
self.out = nn.Linear(num_keypoints * 2, num_keypoints * 2)
|
||||
self.relu = nn.ReLU() if relu else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
return self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1))
|
||||
# TODO(now): make nonlinearity optional
|
||||
return self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)))
|
||||
|
||||
|
||||
class MultiImageObsEncoder(ModuleAttrMixin):
|
||||
@@ -182,7 +185,6 @@ class MultiImageObsEncoder(ModuleAttrMixin):
|
||||
feature = torch.moveaxis(feature, 0, 1)
|
||||
# (B,N*D)
|
||||
feature = feature.reshape(batch_size, -1)
|
||||
# feature = torch.nn.functional.relu(feature) # TODO: make optional
|
||||
features.append(feature)
|
||||
else:
|
||||
# run each rgb obs to independent models
|
||||
@@ -195,7 +197,6 @@ class MultiImageObsEncoder(ModuleAttrMixin):
|
||||
assert img.shape[1:] == self.key_shape_map[key]
|
||||
img = self.key_transform_map[key](img)
|
||||
feature = self.key_model_map[key](img)
|
||||
# feature = torch.nn.functional.relu(feature) # TODO: make optional
|
||||
features.append(feature)
|
||||
|
||||
# concatenate all features
|
||||
|
||||
Reference in New Issue
Block a user