ready for review

This commit is contained in:
Alexander Soare
2024-03-21 10:18:50 +00:00
parent d323993569
commit acf1174447
12 changed files with 282 additions and 85 deletions

View File

@@ -1,3 +1,44 @@
"""Code from the original diffusion policy project.
Notes on how to load a checkpoint from the original repository:
In the original repository, run the eval and use a breakpoint to extract the policy weights.
```
torch.save(policy.state_dict(), "weights.pt")
```
In this repository, add a breakpoint somewhere after creating an equivalent policy and load in the weights:
```
loaded = torch.load("weights.pt")
aligned = {}
their_prefix = "obs_encoder.obs_nets.image.backbone"
our_prefix = "obs_encoder.key_model_map.image.backbone"
aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
their_prefix = "obs_encoder.obs_nets.image.pool"
our_prefix = "obs_encoder.key_model_map.image.pool"
aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
their_prefix = "obs_encoder.obs_nets.image.nets.3"
our_prefix = "obs_encoder.key_model_map.image.out"
aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
aligned.update({k: v for k, v in loaded.items() if k.startswith('model.')})
# Note: this loads the weights into the EMA model (`policy.ema_diffusion`), not into `policy.diffusion`.
missing_keys, unexpected_keys = policy.ema_diffusion.load_state_dict(aligned, strict=False)
assert all('_dummy_variable' in k for k in missing_keys)
assert len(unexpected_keys) == 0
```
Then, in that same runtime, you can also save the weights with the new aligned state_dict (note that this writes to the same `weights.pt` path that was loaded above):
```
policy.save("weights.pt")
```
Now you can remove the breakpoint and the extra code, and load the weights just as you would any other lerobot checkpoint.
"""
from typing import Dict
import torch

View File

@@ -1,11 +1,10 @@
import copy
from typing import Dict, Optional, Tuple, Union
import timm
import torch
import torch.nn as nn
import torchvision
from robomimic.models.base_nets import SpatialSoftmax
from robomimic.models.base_nets import ResNet18Conv, SpatialSoftmax
from lerobot.common.policies.diffusion.model.crop_randomizer import CropRandomizer
from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin
@@ -15,17 +14,16 @@ from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules
class RgbEncoder(nn.Module):
"""Following `VisualCore` from Robomimic 0.2.0."""
def __init__(self, input_shape, model_name="resnet18", pretrained=False, relu=True, num_keypoints=32):
def __init__(self, input_shape, relu=True, pretrained=False, num_keypoints=32):
"""
input_shape: channel-first input shape (C, H, W)
resnet_name: a timm model name.
pretrained: whether to use timm pretrained weights.
rele: whether to use relu as a final step.
relu: whether to use relu as a final step.
num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image).
"""
super().__init__()
self.backbone = timm.create_model(model_name, pretrained, num_classes=0, global_pool="")
# self.backbone = ResNet18Conv(input_channel=input_shape[0])
self.backbone = ResNet18Conv(input_channel=input_shape[0], pretrained=pretrained)
# Figure out the feature map shape.
with torch.inference_mode():
feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:])
@@ -34,7 +32,6 @@ class RgbEncoder(nn.Module):
self.relu = nn.ReLU() if relu else nn.Identity()
def forward(self, x):
# TODO(now): make nonlinearity optional
return self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)))

View File

@@ -5,7 +5,6 @@ import time
import hydra
import torch
from lerobot.common.ema import update_ema_parameters
from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
@@ -21,6 +20,7 @@ class DiffusionPolicy(AbstractPolicy):
cfg_rgb_model,
cfg_obs_encoder,
cfg_optimizer,
cfg_ema,
shape_meta: dict,
horizon,
n_action_steps,
@@ -71,8 +71,13 @@ class DiffusionPolicy(AbstractPolicy):
self.diffusion.cuda()
self.ema_diffusion = None
if self.cfg.ema.enable:
self.ema = None
if self.cfg.use_ema:
self.ema_diffusion = copy.deepcopy(self.diffusion)
self.ema = hydra.utils.instantiate(
cfg_ema,
model=self.ema_diffusion,
)
self.optimizer = hydra.utils.instantiate(
cfg_optimizer,
@@ -175,8 +180,8 @@ class DiffusionPolicy(AbstractPolicy):
self.optimizer.zero_grad()
self.lr_scheduler.step()
if self.cfg.ema.enable:
update_ema_parameters(self.ema_diffusion, self.diffusion, self.cfg.ema.rate)
if self.ema is not None:
self.ema.step(self.diffusion)
info = {
"loss": loss.item(),