Files
lerobot/lerobot/common/policies/diffusion/pytorch_utils.py
2024-03-10 16:33:03 +01:00

77 lines
2.2 KiB
Python

from typing import Callable, Dict
import torch
import torch.nn as nn
import torchvision
def get_resnet(name, weights=None, **kwargs):
    """Build a ResNet backbone whose ``fc`` layer is replaced by an identity.

    name: attribute name looked up on ``torchvision.models``,
        e.g. resnet18, resnet34, resnet50
    weights: passed to the torchvision constructor (e.g. "IMAGENET1K_V1"),
        or "r3m" / "R3M" to delegate loading to ``get_r3m``
    kwargs: forwarded to whichever constructor is selected
    """
    if weights in ("r3m", "R3M"):
        # R3M weights are loaded through their own helper, not torchvision.
        return get_r3m(name=name, **kwargs)

    constructor = getattr(torchvision.models, name)
    backbone = constructor(weights=weights, **kwargs)
    # Swap the final fully-connected layer for a pass-through module.
    backbone.fc = torch.nn.Identity()
    return backbone
def get_r3m(name, **kwargs):
    """Load an R3M model and return its ``convnet`` moved to the CPU.

    name: resnet18, resnet34, resnet50
    kwargs: accepted for call-site compatibility; not used here.
    """
    import r3m

    # Set the package-level device before loading so loading targets the CPU.
    r3m.device = "cpu"
    loaded = r3m.load_r3m(name)
    # ``loaded.module`` is presumably the model inside a parallel wrapper
    # (TODO confirm against the r3m package); its ``convnet`` is what we return.
    return loaded.module.convnet.to("cpu")
def dict_apply(
    x: Dict[str, torch.Tensor], func: Callable[[torch.Tensor], torch.Tensor]
) -> Dict[str, torch.Tensor]:
    """Return a new dict with ``func`` applied to every non-dict value of ``x``.

    Values that are themselves dicts are processed recursively, so the
    nesting structure and key order of ``x`` are preserved. ``x`` itself is
    not modified.
    """
    return {
        name: dict_apply(item, func) if isinstance(item, dict) else func(item)
        for name, item in x.items()
    }
def replace_submodules(
    root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
) -> nn.Module:
    """Replace every submodule selected by ``predicate`` with ``func(submodule)``.

    predicate: Return true if the module is to be replaced.
    func: Return new module to use.

    If ``root_module`` itself matches, ``func(root_module)`` is returned and
    nothing else is touched. Otherwise ``root_module`` is modified in place
    and returned.

    A module registered under several names (a shared submodule) is replaced
    at every location with the *same* new module, so sharing is preserved and
    ``func`` is called once per distinct source module.

    Raises AssertionError if a module matching ``predicate`` is still present
    afterwards (e.g. when ``func`` returns a module that itself matches).
    """
    if predicate(root_module):
        return func(root_module)

    # remove_duplicate=False lists *every* path to a matching module. With
    # de-duplication, a shared module would only be replaced under its first
    # name and the stale second reference would fail the check below.
    target_paths = [
        name.split(".")
        for name, module in root_module.named_modules(remove_duplicate=False)
        if predicate(module)
    ]

    # Source module -> its replacement. Keyed by the module object itself
    # (identity hash), which also keeps the source alive while we iterate.
    replacements: Dict[nn.Module, nn.Module] = {}

    for *parent_path, key in target_paths:
        parent_module = root_module
        if parent_path:
            parent_module = root_module.get_submodule(".".join(parent_path))

        # nn.Sequential children are addressed by integer index.
        is_sequential = isinstance(parent_module, nn.Sequential)
        current = parent_module[int(key)] if is_sequential else getattr(parent_module, key)

        # The slot may already hold a replacement when its parent is shared
        # and was reached earlier through another path.
        if not predicate(current):
            continue

        if current not in replacements:
            replacements[current] = func(current)
        tgt_module = replacements[current]

        if is_sequential:
            parent_module[int(key)] = tgt_module
        else:
            setattr(parent_module, key, tgt_module)

    # verify that all matching modules are replaced
    assert not any(predicate(m) for m in root_module.modules()), (
        "replace_submodules: modules matching the predicate remain after replacement"
    )
    return root_module