77 lines
2.2 KiB
Python
77 lines
2.2 KiB
Python
from typing import Callable, Dict
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torchvision
|
|
|
|
|
|
def get_resnet(name, weights=None, **kwargs):
    """Build a ResNet trunk from torchvision (or R3M) without its classifier.

    Args:
        name: Name of a constructor in ``torchvision.models``, e.g.
            ``"resnet18"``, ``"resnet34"`` or ``"resnet50"``.
        weights: Forwarded to the torchvision constructor, e.g.
            ``"IMAGENET1K_V1"``, or ``None`` (the default) for no pretrained
            weights. The string ``"r3m"`` (matched case-insensitively) instead
            loads R3M weights through ``get_r3m``.
        **kwargs: Forwarded to the torchvision constructor, or to ``get_r3m``
            on the R3M path.

    Returns:
        On the torchvision path, the network with ``fc`` replaced by
        ``torch.nn.Identity``, so its output is what would have been fed into
        ``fc``. On the R3M path, whatever ``get_r3m`` returns.

    Raises:
        AttributeError: if ``name`` is not an attribute of ``torchvision.models``.
    """
    # Only strings can select R3M; the isinstance guard keeps None (the
    # default) or any non-string weights object away from .lower().
    if isinstance(weights, str) and weights.lower() == "r3m":
        return get_r3m(name=name, **kwargs)

    constructor = getattr(torchvision.models, name)
    resnet = constructor(weights=weights, **kwargs)
    # Drop the classification head so the caller gets the features instead.
    resnet.fc = torch.nn.Identity()
    return resnet
|
|
|
|
|
|
def get_r3m(name, **kwargs):
    """Load a pretrained R3M model and return its convolutional trunk on CPU.

    Args:
        name: Variant name passed straight to ``r3m.load_r3m``, e.g.
            ``"resnet18"``, ``"resnet34"`` or ``"resnet50"``.
        **kwargs: Accepted so the signature mirrors ``get_resnet``; unused.

    Returns:
        The ``convnet`` attribute of the loaded model's ``.module``, moved to
        ``"cpu"``.
    """
    # Imported lazily, so r3m is only required when R3M weights are requested.
    import r3m

    # Set before loading; presumably makes r3m place the model on CPU —
    # NOTE(review): confirm against the r3m package.
    r3m.device = "cpu"
    # load_r3m appears to return a wrapper whose real model sits under
    # `.module` (likely a DataParallel) — TODO confirm.
    loaded = r3m.load_r3m(name)
    return loaded.module.convnet.to("cpu")
|
|
|
|
|
|
def dict_apply(
    x: Dict[str, torch.Tensor], func: Callable[[torch.Tensor], torch.Tensor]
) -> Dict[str, torch.Tensor]:
    """Apply ``func`` to every leaf value of a possibly nested dict.

    Values that are ``dict`` instances are recursed into; every other value is
    passed to ``func``. A fresh plain ``dict`` with the same keys (and key
    order) is built at every level, so ``x`` itself is not mutated.

    Args:
        x: Mapping whose leaves are transformed; may contain nested dicts.
        func: Function applied to each non-dict value.

    Returns:
        A new dict mirroring the nesting of ``x`` with transformed leaves.
    """
    return {
        key: dict_apply(value, func) if isinstance(value, dict) else func(value)
        for key, value in x.items()
    }
|
|
|
|
|
|
def replace_submodules(
    root_module: nn.Module,
    predicate: Callable[[nn.Module], bool],
    func: Callable[[nn.Module], nn.Module],
) -> nn.Module:
    """Replace every submodule of ``root_module`` that matches ``predicate``.

    Matches are found in ``named_modules`` order (parents before children)
    and swapped in place on their parent module.

    Args:
        root_module: Module tree to modify.
        predicate: Returns True if the module is to be replaced.
        func: Returns the new module to use in place of a matching one. It is
            called once per distinct matching module object, so a module
            registered under several names is replaced by a single shared
            replacement.

    Returns:
        ``func(root_module)`` if the root itself matches; otherwise
        ``root_module`` (modified in place).

    Raises:
        AssertionError: if any submodule still matches ``predicate`` after
            replacement (e.g. ``func`` returned a matching module).
    """
    if predicate(root_module):
        return func(root_module)

    # remove_duplicate=False: a module registered at several paths (shared
    # weights) must be swapped at every path, not only the first one found.
    target_paths = [
        name
        for name, module in root_module.named_modules(remove_duplicate=False)
        if predicate(module)
    ]

    # Keyed by the original module object (nn.Module hashes by identity);
    # holding the key also keeps it alive, so identities are never recycled.
    replacements: Dict[nn.Module, nn.Module] = {}
    for path in target_paths:
        parent_path, _, child_name = path.rpartition(".")
        parent_module = (
            root_module.get_submodule(parent_path) if parent_path else root_module
        )
        # getattr/setattr work for Sequential, ModuleList and ModuleDict alike,
        # because children are stored in `_modules` under their string names.
        # This also handles named (OrderedDict-built) Sequential children,
        # where int(child_name) would raise.
        src_module = getattr(parent_module, child_name)
        if src_module not in replacements:
            replacements[src_module] = func(src_module)
        setattr(parent_module, child_name, replacements[src_module])

    # Verify that every match was replaced. Raised explicitly (not `assert`)
    # so the check survives `python -O`; AssertionError is kept for
    # compatibility with callers of the original implementation.
    remaining = [
        name for name, module in root_module.named_modules() if predicate(module)
    ]
    if remaining:
        raise AssertionError(
            f"submodules still match predicate after replacement: {remaining}"
        )
    return root_module
|