early training loss as expected
This commit is contained in:
@@ -1,15 +1,37 @@
|
||||
import copy
|
||||
from typing import Dict, Tuple, Union
|
||||
|
||||
import timm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
from robomimic.models.base_nets import SpatialSoftmax
|
||||
|
||||
from lerobot.common.policies.diffusion.model.crop_randomizer import CropRandomizer
|
||||
from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin
|
||||
from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules
|
||||
|
||||
|
||||
class RgbEncoder(nn.Module):
|
||||
"""Following `VisualCore` from Robomimic 0.2.0."""
|
||||
|
||||
def __init__(self, input_shape, model_name="resnet18", pretrained=False, num_keypoints=32):
|
||||
"""
|
||||
resnet_name: a timm model name.
|
||||
pretrained: whether to use timm pretrained weights.
|
||||
num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image).
|
||||
"""
|
||||
super().__init__()
|
||||
self.backbone = timm.create_model(model_name, pretrained, num_classes=0, global_pool="")
|
||||
# Figure out the feature map shape.
|
||||
with torch.inference_mode():
|
||||
feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:])
|
||||
self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints)
|
||||
|
||||
def forward(self, x):
|
||||
return torch.flatten(self.pool(self.backbone(x)), start_dim=1)
|
||||
|
||||
|
||||
class MultiImageObsEncoder(ModuleAttrMixin):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -101,7 +123,8 @@ class MultiImageObsEncoder(ModuleAttrMixin):
|
||||
if imagenet_norm:
|
||||
# TODO(rcadene): move normalizer to dataset and env
|
||||
this_normalizer = torchvision.transforms.Normalize(
|
||||
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
||||
mean=[127.5, 127.5, 127.5],
|
||||
std=[127.5, 127.5, 127.5],
|
||||
)
|
||||
|
||||
this_transform = nn.Sequential(this_resizer, this_randomizer, this_normalizer)
|
||||
|
||||
Reference in New Issue
Block a user