ready for review
This commit is contained in:
@@ -1,11 +1,10 @@
|
||||
import copy
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import timm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
from robomimic.models.base_nets import SpatialSoftmax
|
||||
from robomimic.models.base_nets import ResNet18Conv, SpatialSoftmax
|
||||
|
||||
from lerobot.common.policies.diffusion.model.crop_randomizer import CropRandomizer
|
||||
from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin
|
||||
@@ -15,17 +14,16 @@ from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules
|
||||
class RgbEncoder(nn.Module):
|
||||
"""Following `VisualCore` from Robomimic 0.2.0."""
|
||||
|
||||
def __init__(self, input_shape, model_name="resnet18", pretrained=False, relu=True, num_keypoints=32):
|
||||
def __init__(self, input_shape, relu=True, pretrained=False, num_keypoints=32):
|
||||
"""
|
||||
input_shape: channel-first input shape (C, H, W)
|
||||
resnet_name: a timm model name.
|
||||
pretrained: whether to use timm pretrained weights.
|
||||
rele: whether to use relu as a final step.
|
||||
relu: whether to use relu as a final step.
|
||||
num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image).
|
||||
"""
|
||||
super().__init__()
|
||||
self.backbone = timm.create_model(model_name, pretrained, num_classes=0, global_pool="")
|
||||
# self.backbone = ResNet18Conv(input_channel=input_shape[0])
|
||||
self.backbone = ResNet18Conv(input_channel=input_shape[0], pretrained=pretrained)
|
||||
# Figure out the feature map shape.
|
||||
with torch.inference_mode():
|
||||
feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:])
|
||||
@@ -34,7 +32,6 @@ class RgbEncoder(nn.Module):
|
||||
self.relu = nn.ReLU() if relu else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
# TODO(now): make nonlinearity optional
|
||||
return self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)))
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user