ready for review

This commit is contained in:
Alexander Soare
2024-03-21 10:18:50 +00:00
parent d323993569
commit acf1174447
12 changed files with 282 additions and 85 deletions

View File

@@ -1,11 +1,10 @@
import copy
from typing import Dict, Optional, Tuple, Union
import timm
import torch
import torch.nn as nn
import torchvision
from robomimic.models.base_nets import SpatialSoftmax
from robomimic.models.base_nets import ResNet18Conv, SpatialSoftmax
from lerobot.common.policies.diffusion.model.crop_randomizer import CropRandomizer
from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin
@@ -15,17 +14,16 @@ from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules
class RgbEncoder(nn.Module):
"""Following `VisualCore` from Robomimic 0.2.0."""
def __init__(self, input_shape, model_name="resnet18", pretrained=False, relu=True, num_keypoints=32):
def __init__(self, input_shape, relu=True, pretrained=False, num_keypoints=32):
"""
input_shape: channel-first input shape (C, H, W)
resnet_name: a timm model name.
pretrained: whether to use timm pretrained weights.
rele: whether to use relu as a final step.
relu: whether to use relu as a final step.
num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image).
"""
super().__init__()
self.backbone = timm.create_model(model_name, pretrained, num_classes=0, global_pool="")
# self.backbone = ResNet18Conv(input_channel=input_shape[0])
self.backbone = ResNet18Conv(input_channel=input_shape[0], pretrained=pretrained)
# Figure out the feature map shape.
with torch.inference_mode():
feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:])
@@ -34,7 +32,6 @@ class RgbEncoder(nn.Module):
self.relu = nn.ReLU() if relu else nn.Identity()
def forward(self, x):
# TODO(now): make nonlinearity optional
return self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)))