backup wip

This commit is contained in:
Alexander Soare
2024-03-20 15:01:27 +00:00
parent 32e3f71dd1
commit d323993569
7 changed files with 71 additions and 81 deletions

View File

@@ -1,4 +1,3 @@
import cv2
import numpy as np
from gym import spaces
@@ -34,14 +33,14 @@ class PushTImageEnv(PushTEnv):
coord = (action / 512 * 96).astype(np.int32)
marker_size = int(8 / 96 * self.render_size)
thickness = int(1 / 96 * self.render_size)
cv2.drawMarker(
img,
coord,
color=(255, 0, 0),
markerType=cv2.MARKER_CROSS,
markerSize=marker_size,
thickness=thickness,
)
# cv2.drawMarker(
# img,
# coord,
# color=(255, 0, 0),
# markerType=cv2.MARKER_CROSS,
# markerSize=marker_size,
# thickness=thickness,
# )
self.render_cache = img
return obs

View File

@@ -15,11 +15,12 @@ from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules
class RgbEncoder(nn.Module):
"""Following `VisualCore` from Robomimic 0.2.0."""
def __init__(self, input_shape, model_name="resnet18", pretrained=False, num_keypoints=32):
def __init__(self, input_shape, model_name="resnet18", pretrained=False, relu=True, num_keypoints=32):
"""
input_shape: channel-first input shape (C, H, W)
resnet_name: a timm model name.
pretrained: whether to use timm pretrained weights.
rele: whether to use relu as a final step.
num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image).
"""
super().__init__()
@@ -30,9 +31,11 @@ class RgbEncoder(nn.Module):
feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:])
self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints)
self.out = nn.Linear(num_keypoints * 2, num_keypoints * 2)
self.relu = nn.ReLU() if relu else nn.Identity()
def forward(self, x):
return self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1))
# TODO(now): make nonlinearity optional
return self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)))
class MultiImageObsEncoder(ModuleAttrMixin):
@@ -182,7 +185,6 @@ class MultiImageObsEncoder(ModuleAttrMixin):
feature = torch.moveaxis(feature, 0, 1)
# (B,N*D)
feature = feature.reshape(batch_size, -1)
# feature = torch.nn.functional.relu(feature) # TODO: make optional
features.append(feature)
else:
# run each rgb obs to independent models
@@ -195,7 +197,6 @@ class MultiImageObsEncoder(ModuleAttrMixin):
assert img.shape[1:] == self.key_shape_map[key]
img = self.key_transform_map[key](img)
feature = self.key_model_map[key](img)
# feature = torch.nn.functional.relu(feature) # TODO: make optional
features.append(feature)
# concatenate all features

View File

@@ -1,9 +1,11 @@
import copy
import logging
import time
import hydra
import torch
from lerobot.common.ema import update_ema_parameters
from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
@@ -19,7 +21,6 @@ class DiffusionPolicy(AbstractPolicy):
cfg_rgb_model,
cfg_obs_encoder,
cfg_optimizer,
cfg_ema,
shape_meta: dict,
horizon,
n_action_steps,
@@ -42,7 +43,6 @@ class DiffusionPolicy(AbstractPolicy):
if cfg_obs_encoder.crop_shape is not None:
rgb_model_input_shape[1:] = cfg_obs_encoder.crop_shape
rgb_model = RgbEncoder(input_shape=rgb_model_input_shape, **cfg_rgb_model)
rgb_model = hydra.utils.instantiate(cfg_rgb_model)
obs_encoder = MultiImageObsEncoder(
rgb_model=rgb_model,
**cfg_obs_encoder,
@@ -70,12 +70,9 @@ class DiffusionPolicy(AbstractPolicy):
if torch.cuda.is_available() and cfg_device == "cuda":
self.diffusion.cuda()
self.ema = None
if self.cfg.use_ema:
self.ema = hydra.utils.instantiate(
cfg_ema,
model=copy.deepcopy(self.diffusion),
)
self.ema_diffusion = None
if self.cfg.ema.enable:
self.ema_diffusion = copy.deepcopy(self.diffusion)
self.optimizer = hydra.utils.instantiate(
cfg_optimizer,
@@ -98,6 +95,9 @@ class DiffusionPolicy(AbstractPolicy):
@torch.no_grad()
def select_actions(self, observation, step_count):
"""
Note: this uses the ema model weights if self.training == False, otherwise the non-ema model weights.
"""
# TODO(rcadene): remove unused step_count
del step_count
@@ -105,7 +105,10 @@ class DiffusionPolicy(AbstractPolicy):
"image": observation["image"],
"agent_pos": observation["state"],
}
out = self.diffusion.predict_action(obs_dict)
if self.training:
out = self.diffusion.predict_action(obs_dict)
else:
out = self.ema_diffusion.predict_action(obs_dict)
action = out["action"]
return action
@@ -172,8 +175,8 @@ class DiffusionPolicy(AbstractPolicy):
self.optimizer.zero_grad()
self.lr_scheduler.step()
if self.ema is not None:
self.ema.step(self.diffusion)
if self.cfg.ema.enable:
update_ema_parameters(self.ema_diffusion, self.diffusion, self.cfg.ema.rate)
info = {
"loss": loss.item(),
@@ -195,4 +198,10 @@ class DiffusionPolicy(AbstractPolicy):
def load(self, fp):
d = torch.load(fp)
self.load_state_dict(d)
missing_keys, unexpected_keys = self.load_state_dict(d, strict=False)
if len(missing_keys) > 0:
assert all(k.startswith("ema_diffusion.") for k in missing_keys)
logging.warning(
"DiffusionPolicy.load expected ema parameters in loaded state dict but none were found."
)
assert len(unexpected_keys) == 0

View File

@@ -16,7 +16,6 @@ def make_policy(cfg):
cfg_rgb_model=cfg.rgb_model,
cfg_obs_encoder=cfg.obs_encoder,
cfg_optimizer=cfg.optimizer,
cfg_ema=cfg.ema,
n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
**cfg.policy,
)
@@ -41,7 +40,7 @@ def make_policy(cfg):
policy.load(cfg.policy.pretrained_model_path)
# import torch
# loaded = torch.load('/home/alexander/Downloads/dp_ema.pth')
# loaded = torch.load('/home/alexander/Downloads/dp.pth')
# aligned = {}
# their_prefix = "obs_encoder.obs_nets.image.backbone"