from typing import Callable, Dict

import torch
import torch.nn as nn


def get_resnet(name, weights=None, **kwargs):
    """Build a ResNet trunk with its classification head removed.

    Args:
        name: Name of a constructor in ``torchvision.models``, e.g.
            ``"resnet18"``, ``"resnet34"`` or ``"resnet50"``.
        weights: Passed through to the torchvision constructor (e.g.
            ``"IMAGENET1K_V1"`` or ``None``). The special values ``"r3m"`` /
            ``"R3M"`` instead load R3M weights via :func:`get_r3m`.
        **kwargs: Extra keyword arguments forwarded to the torchvision
            constructor (or to :func:`get_r3m`, which ignores them).

    Returns:
        The network with ``fc`` replaced by ``nn.Identity``, so its output is
        the input to the former classification layer rather than logits.
    """
    if weights in ("r3m", "R3M"):
        return get_r3m(name=name, **kwargs)

    # Imported lazily (as ``r3m`` is in get_r3m) so the other helpers in this
    # module remain importable when torchvision is not installed.
    import torchvision

    func = getattr(torchvision.models, name)
    resnet = func(weights=weights, **kwargs)
    resnet.fc = torch.nn.Identity()
    return resnet


def get_r3m(name, **kwargs):
    """Load the convolutional trunk of an R3M-pretrained ResNet on the CPU.

    Args:
        name: ResNet variant passed to ``r3m.load_r3m`` (e.g. ``"resnet18"``,
            ``"resnet34"``, ``"resnet50"``).
        **kwargs: Accepted for signature compatibility with
            :func:`get_resnet`; currently ignored.

    Returns:
        The ``convnet`` submodule of the loaded R3M model, moved to ``"cpu"``.
    """
    import r3m

    # NOTE(review): presumably makes r3m load its model onto the CPU — confirm
    # against the r3m package. This sets a module-level attribute on ``r3m``,
    # so it affects every later use of that package in this process.
    r3m.device = "cpu"
    model = r3m.load_r3m(name)
    # ``.module`` suggests load_r3m returns a DataParallel-style wrapper —
    # TODO confirm.
    r3m_model = model.module
    resnet_model = r3m_model.convnet
    resnet_model = resnet_model.to("cpu")
    return resnet_model


def dict_apply(
    x: Dict[str, torch.Tensor], func: Callable[[torch.Tensor], torch.Tensor]
) -> Dict[str, torch.Tensor]:
    """Apply ``func`` to every leaf value of a (possibly nested) dict.

    Values that are themselves ``dict`` instances are recursed into; every
    other value is passed to ``func``. A new dict with the same keys and
    nesting is returned; ``x`` is not modified.
    """
    result = {}
    for key, value in x.items():
        if isinstance(value, dict):
            result[key] = dict_apply(value, func)
        else:
            result[key] = func(value)
    return result


def replace_submodules(
    root_module: nn.Module,
    predicate: Callable[[nn.Module], bool],
    func: Callable[[nn.Module], nn.Module],
) -> nn.Module:
    """Replace every submodule matching ``predicate`` with ``func(submodule)``.

    Args:
        root_module: Module tree to modify in place.
        predicate: Return True if the module is to be replaced.
        func: Return the new module to use in place of the given one.

    Returns:
        ``func(root_module)`` if the root itself matches; otherwise
        ``root_module`` after in-place replacement.

    Raises:
        AssertionError: if a module matching ``predicate`` is still reachable
            after replacement (e.g. ``func`` returned a matching module).
    """
    if predicate(root_module):
        return func(root_module)

    # Collect *every* path to a matching module. With the default
    # remove_duplicate=True a module registered under several parents is only
    # listed once, so its other references survived and tripped the final check.
    paths = [
        name.split(".")
        for name, module in root_module.named_modules(remove_duplicate=False)
        if predicate(module)
    ]

    # id(src) -> (src, tgt): each distinct source module is converted once and
    # its replacement is shared exactly as the original was. Holding ``src``
    # keeps it alive so its id cannot be reused during this call.
    replacements = {}
    # ids of modules this call has inserted; a slot reached a second time via
    # another path to the same (shared) container is already done.
    placed = set()

    for *parent, k in paths:
        parent_module = root_module
        if len(parent) > 0:
            parent_module = root_module.get_submodule(".".join(parent))
        if isinstance(parent_module, nn.Sequential):
            src_module = parent_module[int(k)]
        else:
            src_module = getattr(parent_module, k)

        if id(src_module) in placed:
            continue

        cached = replacements.get(id(src_module))
        if cached is None:
            cached = (src_module, func(src_module))
            replacements[id(src_module)] = cached
        tgt_module = cached[1]
        placed.add(id(tgt_module))

        if isinstance(parent_module, nn.Sequential):
            parent_module[int(k)] = tgt_module
        else:
            setattr(parent_module, k, tgt_module)

    # Verify that all matching modules were replaced. Raised explicitly rather
    # than via a bare ``assert`` so the check also runs under ``python -O``.
    remaining = [
        name
        for name, module in root_module.named_modules(remove_duplicate=True)
        if predicate(module)
    ]
    if len(remaining) > 0:
        raise AssertionError(
            f"modules still matching predicate after replacement: {remaining}"
        )
    return root_module