Integrate diffusion policy
This commit is contained in:
46
lerobot/common/policies/diffusion/pytorch_utils.py
Normal file
46
lerobot/common/policies/diffusion/pytorch_utils.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from typing import Callable, Dict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def dict_apply(
|
||||
x: Dict[str, torch.Tensor], func: Callable[[torch.Tensor], torch.Tensor]
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
result = {}
|
||||
for key, value in x.items():
|
||||
if isinstance(value, dict):
|
||||
result[key] = dict_apply(value, func)
|
||||
else:
|
||||
result[key] = func(value)
|
||||
return result
|
||||
|
||||
|
||||
def replace_submodules(
|
||||
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
|
||||
) -> nn.Module:
|
||||
"""
|
||||
predicate: Return true if the module is to be replaced.
|
||||
func: Return new module to use.
|
||||
"""
|
||||
if predicate(root_module):
|
||||
return func(root_module)
|
||||
|
||||
bn_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
|
||||
for *parent, k in bn_list:
|
||||
parent_module = root_module
|
||||
if len(parent) > 0:
|
||||
parent_module = root_module.get_submodule(".".join(parent))
|
||||
if isinstance(parent_module, nn.Sequential):
|
||||
src_module = parent_module[int(k)]
|
||||
else:
|
||||
src_module = getattr(parent_module, k)
|
||||
tgt_module = func(src_module)
|
||||
if isinstance(parent_module, nn.Sequential):
|
||||
parent_module[int(k)] = tgt_module
|
||||
else:
|
||||
setattr(parent_module, k, tgt_module)
|
||||
# verify that all BN are replaced
|
||||
bn_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)]
|
||||
assert len(bn_list) == 0
|
||||
return root_module
|
||||
Reference in New Issue
Block a user