from typing import Callable, Dict

import torch
import torch.nn as nn


def dict_apply(
    x: Dict[str, torch.Tensor], func: Callable[[torch.Tensor], torch.Tensor]
) -> Dict[str, torch.Tensor]:
    """Apply ``func`` to every non-dict value of ``x``, recursing into nested dicts.

    Returns a new dict with the same (nested) key structure; ``x`` itself is
    not modified.
    """
    return {
        key: dict_apply(value, func) if isinstance(value, dict) else func(value)
        for key, value in x.items()
    }


def replace_submodules(
    root_module: nn.Module,
    predicate: Callable[[nn.Module], bool],
    func: Callable[[nn.Module], nn.Module],
) -> nn.Module:
    """Replace every submodule matching ``predicate`` with ``func(submodule)``.

    predicate: Return true if the module is to be replaced.
    func: Return new module to use.

    If ``root_module`` itself matches, ``func(root_module)`` is returned
    directly. Otherwise ``root_module`` is modified in place and returned.

    Matches nested inside other matches are replaced first (deepest first), so
    ``func`` for an outer match receives it with its inner matches already
    replaced. A module instance shared between several parents is converted
    once and that single replacement is installed at every location, so
    sharing is preserved.

    Raises:
        AssertionError: if any module matching ``predicate`` remains after
            replacement (e.g. ``func`` returned a module that still matches).
    """
    if predicate(root_module):
        return func(root_module)

    # remove_duplicate=False lists a shared instance at every path it occurs at,
    # so no location is left holding the old module.
    matches = [
        (name, module)
        for name, module in root_module.named_modules(remove_duplicate=False)
        if predicate(module)
    ]

    replacements: Dict[nn.Module, nn.Module] = {}
    # named_modules yields parents before children. Walking backwards replaces
    # children before their ancestors; replacing an ancestor first would make
    # the paths of its (matching) descendants unresolvable.
    for name, src_module in reversed(matches):
        *parent_path, attr = name.split(".")
        # get_submodule("") returns root_module itself.
        parent_module = root_module.get_submodule(".".join(parent_path))
        if getattr(parent_module, attr) is not src_module:
            # Already replaced through another path to a shared parent module.
            continue
        if src_module not in replacements:
            replacements[src_module] = func(src_module)
        # setattr also covers nn.Sequential / nn.ModuleList, whose children are
        # registered under their string indices ("0", "1", ...).
        setattr(parent_module, attr, replacements[src_module])

    # Verify that nothing matching is left; an explicit raise (rather than a
    # bare assert) keeps the check active under `python -O`.
    remaining = [
        name for name, module in root_module.named_modules() if predicate(module)
    ]
    if remaining:
        raise AssertionError(
            f"modules still matching predicate after replacement: {remaining}"
        )
    return root_module