backup wip

This commit is contained in:
Alexander Soare
2024-03-20 15:01:27 +00:00
parent 32e3f71dd1
commit d323993569
7 changed files with 71 additions and 81 deletions

View File

@@ -15,11 +15,12 @@ from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules
class RgbEncoder(nn.Module):
"""Following `VisualCore` from Robomimic 0.2.0."""
def __init__(self, input_shape, model_name="resnet18", pretrained=False, num_keypoints=32):
def __init__(self, input_shape, model_name="resnet18", pretrained=False, relu=True, num_keypoints=32):
"""
input_shape: channel-first input shape (C, H, W)
resnet_name: a timm model name.
pretrained: whether to use timm pretrained weights.
rele: whether to use relu as a final step.
num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image).
"""
super().__init__()
@@ -30,9 +31,11 @@ class RgbEncoder(nn.Module):
feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:])
self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints)
self.out = nn.Linear(num_keypoints * 2, num_keypoints * 2)
self.relu = nn.ReLU() if relu else nn.Identity()
def forward(self, x):
return self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1))
# TODO(now): make nonlinearity optional
return self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)))
class MultiImageObsEncoder(ModuleAttrMixin):
@@ -182,7 +185,6 @@ class MultiImageObsEncoder(ModuleAttrMixin):
feature = torch.moveaxis(feature, 0, 1)
# (B,N*D)
feature = feature.reshape(batch_size, -1)
# feature = torch.nn.functional.relu(feature) # TODO: make optional
features.append(feature)
else:
# run each rgb obs to independent models
@@ -195,7 +197,6 @@ class MultiImageObsEncoder(ModuleAttrMixin):
assert img.shape[1:] == self.key_shape_map[key]
img = self.key_transform_map[key](img)
feature = self.key_model_map[key](img)
# feature = torch.nn.functional.relu(feature) # TODO: make optional
features.append(feature)
# concatenate all features

View File

@@ -1,9 +1,11 @@
import copy
import logging
import time
import hydra
import torch
from lerobot.common.ema import update_ema_parameters
from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
@@ -19,7 +21,6 @@ class DiffusionPolicy(AbstractPolicy):
cfg_rgb_model,
cfg_obs_encoder,
cfg_optimizer,
cfg_ema,
shape_meta: dict,
horizon,
n_action_steps,
@@ -42,7 +43,6 @@ class DiffusionPolicy(AbstractPolicy):
if cfg_obs_encoder.crop_shape is not None:
rgb_model_input_shape[1:] = cfg_obs_encoder.crop_shape
rgb_model = RgbEncoder(input_shape=rgb_model_input_shape, **cfg_rgb_model)
rgb_model = hydra.utils.instantiate(cfg_rgb_model)
obs_encoder = MultiImageObsEncoder(
rgb_model=rgb_model,
**cfg_obs_encoder,
@@ -70,12 +70,9 @@ class DiffusionPolicy(AbstractPolicy):
if torch.cuda.is_available() and cfg_device == "cuda":
self.diffusion.cuda()
self.ema = None
if self.cfg.use_ema:
self.ema = hydra.utils.instantiate(
cfg_ema,
model=copy.deepcopy(self.diffusion),
)
self.ema_diffusion = None
if self.cfg.ema.enable:
self.ema_diffusion = copy.deepcopy(self.diffusion)
self.optimizer = hydra.utils.instantiate(
cfg_optimizer,
@@ -98,6 +95,9 @@ class DiffusionPolicy(AbstractPolicy):
@torch.no_grad()
def select_actions(self, observation, step_count):
"""
Note: this uses the ema model weights if self.training == False, otherwise the non-ema model weights.
"""
# TODO(rcadene): remove unused step_count
del step_count
@@ -105,7 +105,10 @@ class DiffusionPolicy(AbstractPolicy):
"image": observation["image"],
"agent_pos": observation["state"],
}
out = self.diffusion.predict_action(obs_dict)
if self.training:
out = self.diffusion.predict_action(obs_dict)
else:
out = self.ema_diffusion.predict_action(obs_dict)
action = out["action"]
return action
@@ -172,8 +175,8 @@ class DiffusionPolicy(AbstractPolicy):
self.optimizer.zero_grad()
self.lr_scheduler.step()
if self.ema is not None:
self.ema.step(self.diffusion)
if self.cfg.ema.enable:
update_ema_parameters(self.ema_diffusion, self.diffusion, self.cfg.ema.rate)
info = {
"loss": loss.item(),
@@ -195,4 +198,10 @@ class DiffusionPolicy(AbstractPolicy):
def load(self, fp):
d = torch.load(fp)
self.load_state_dict(d)
missing_keys, unexpected_keys = self.load_state_dict(d, strict=False)
if len(missing_keys) > 0:
assert all(k.startswith("ema_diffusion.") for k in missing_keys)
logging.warning(
"DiffusionPolicy.load expected ema parameters in loaded state dict but none were found."
)
assert len(unexpected_keys) == 0