backup wip
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
|
||||
@@ -34,14 +33,14 @@ class PushTImageEnv(PushTEnv):
|
||||
coord = (action / 512 * 96).astype(np.int32)
|
||||
marker_size = int(8 / 96 * self.render_size)
|
||||
thickness = int(1 / 96 * self.render_size)
|
||||
cv2.drawMarker(
|
||||
img,
|
||||
coord,
|
||||
color=(255, 0, 0),
|
||||
markerType=cv2.MARKER_CROSS,
|
||||
markerSize=marker_size,
|
||||
thickness=thickness,
|
||||
)
|
||||
# cv2.drawMarker(
|
||||
# img,
|
||||
# coord,
|
||||
# color=(255, 0, 0),
|
||||
# markerType=cv2.MARKER_CROSS,
|
||||
# markerSize=marker_size,
|
||||
# thickness=thickness,
|
||||
# )
|
||||
self.render_cache = img
|
||||
|
||||
return obs
|
||||
|
||||
@@ -15,11 +15,12 @@ from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules
|
||||
class RgbEncoder(nn.Module):
|
||||
"""Following `VisualCore` from Robomimic 0.2.0."""
|
||||
|
||||
def __init__(self, input_shape, model_name="resnet18", pretrained=False, num_keypoints=32):
|
||||
def __init__(self, input_shape, model_name="resnet18", pretrained=False, relu=True, num_keypoints=32):
|
||||
"""
|
||||
input_shape: channel-first input shape (C, H, W)
|
||||
resnet_name: a timm model name.
|
||||
pretrained: whether to use timm pretrained weights.
|
||||
rele: whether to use relu as a final step.
|
||||
num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image).
|
||||
"""
|
||||
super().__init__()
|
||||
@@ -30,9 +31,11 @@ class RgbEncoder(nn.Module):
|
||||
feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:])
|
||||
self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints)
|
||||
self.out = nn.Linear(num_keypoints * 2, num_keypoints * 2)
|
||||
self.relu = nn.ReLU() if relu else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
return self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1))
|
||||
# TODO(now): make nonlinearity optional
|
||||
return self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)))
|
||||
|
||||
|
||||
class MultiImageObsEncoder(ModuleAttrMixin):
|
||||
@@ -182,7 +185,6 @@ class MultiImageObsEncoder(ModuleAttrMixin):
|
||||
feature = torch.moveaxis(feature, 0, 1)
|
||||
# (B,N*D)
|
||||
feature = feature.reshape(batch_size, -1)
|
||||
# feature = torch.nn.functional.relu(feature) # TODO: make optional
|
||||
features.append(feature)
|
||||
else:
|
||||
# run each rgb obs to independent models
|
||||
@@ -195,7 +197,6 @@ class MultiImageObsEncoder(ModuleAttrMixin):
|
||||
assert img.shape[1:] == self.key_shape_map[key]
|
||||
img = self.key_transform_map[key](img)
|
||||
feature = self.key_model_map[key](img)
|
||||
# feature = torch.nn.functional.relu(feature) # TODO: make optional
|
||||
features.append(feature)
|
||||
|
||||
# concatenate all features
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import copy
|
||||
import logging
|
||||
import time
|
||||
|
||||
import hydra
|
||||
import torch
|
||||
|
||||
from lerobot.common.ema import update_ema_parameters
|
||||
from lerobot.common.policies.abstract import AbstractPolicy
|
||||
from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
|
||||
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
|
||||
@@ -19,7 +21,6 @@ class DiffusionPolicy(AbstractPolicy):
|
||||
cfg_rgb_model,
|
||||
cfg_obs_encoder,
|
||||
cfg_optimizer,
|
||||
cfg_ema,
|
||||
shape_meta: dict,
|
||||
horizon,
|
||||
n_action_steps,
|
||||
@@ -42,7 +43,6 @@ class DiffusionPolicy(AbstractPolicy):
|
||||
if cfg_obs_encoder.crop_shape is not None:
|
||||
rgb_model_input_shape[1:] = cfg_obs_encoder.crop_shape
|
||||
rgb_model = RgbEncoder(input_shape=rgb_model_input_shape, **cfg_rgb_model)
|
||||
rgb_model = hydra.utils.instantiate(cfg_rgb_model)
|
||||
obs_encoder = MultiImageObsEncoder(
|
||||
rgb_model=rgb_model,
|
||||
**cfg_obs_encoder,
|
||||
@@ -70,12 +70,9 @@ class DiffusionPolicy(AbstractPolicy):
|
||||
if torch.cuda.is_available() and cfg_device == "cuda":
|
||||
self.diffusion.cuda()
|
||||
|
||||
self.ema = None
|
||||
if self.cfg.use_ema:
|
||||
self.ema = hydra.utils.instantiate(
|
||||
cfg_ema,
|
||||
model=copy.deepcopy(self.diffusion),
|
||||
)
|
||||
self.ema_diffusion = None
|
||||
if self.cfg.ema.enable:
|
||||
self.ema_diffusion = copy.deepcopy(self.diffusion)
|
||||
|
||||
self.optimizer = hydra.utils.instantiate(
|
||||
cfg_optimizer,
|
||||
@@ -98,6 +95,9 @@ class DiffusionPolicy(AbstractPolicy):
|
||||
|
||||
@torch.no_grad()
|
||||
def select_actions(self, observation, step_count):
|
||||
"""
|
||||
Note: this uses the ema model weights if self.training == False, otherwise the non-ema model weights.
|
||||
"""
|
||||
# TODO(rcadene): remove unused step_count
|
||||
del step_count
|
||||
|
||||
@@ -105,7 +105,10 @@ class DiffusionPolicy(AbstractPolicy):
|
||||
"image": observation["image"],
|
||||
"agent_pos": observation["state"],
|
||||
}
|
||||
out = self.diffusion.predict_action(obs_dict)
|
||||
if self.training:
|
||||
out = self.diffusion.predict_action(obs_dict)
|
||||
else:
|
||||
out = self.ema_diffusion.predict_action(obs_dict)
|
||||
action = out["action"]
|
||||
return action
|
||||
|
||||
@@ -172,8 +175,8 @@ class DiffusionPolicy(AbstractPolicy):
|
||||
self.optimizer.zero_grad()
|
||||
self.lr_scheduler.step()
|
||||
|
||||
if self.ema is not None:
|
||||
self.ema.step(self.diffusion)
|
||||
if self.cfg.ema.enable:
|
||||
update_ema_parameters(self.ema_diffusion, self.diffusion, self.cfg.ema.rate)
|
||||
|
||||
info = {
|
||||
"loss": loss.item(),
|
||||
@@ -195,4 +198,10 @@ class DiffusionPolicy(AbstractPolicy):
|
||||
|
||||
def load(self, fp):
|
||||
d = torch.load(fp)
|
||||
self.load_state_dict(d)
|
||||
missing_keys, unexpected_keys = self.load_state_dict(d, strict=False)
|
||||
if len(missing_keys) > 0:
|
||||
assert all(k.startswith("ema_diffusion.") for k in missing_keys)
|
||||
logging.warning(
|
||||
"DiffusionPolicy.load expected ema parameters in loaded state dict but none were found."
|
||||
)
|
||||
assert len(unexpected_keys) == 0
|
||||
|
||||
@@ -16,7 +16,6 @@ def make_policy(cfg):
|
||||
cfg_rgb_model=cfg.rgb_model,
|
||||
cfg_obs_encoder=cfg.obs_encoder,
|
||||
cfg_optimizer=cfg.optimizer,
|
||||
cfg_ema=cfg.ema,
|
||||
n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
|
||||
**cfg.policy,
|
||||
)
|
||||
@@ -41,7 +40,7 @@ def make_policy(cfg):
|
||||
policy.load(cfg.policy.pretrained_model_path)
|
||||
|
||||
# import torch
|
||||
# loaded = torch.load('/home/alexander/Downloads/dp_ema.pth')
|
||||
# loaded = torch.load('/home/alexander/Downloads/dp.pth')
|
||||
# aligned = {}
|
||||
|
||||
# their_prefix = "obs_encoder.obs_nets.image.backbone"
|
||||
|
||||
Reference in New Issue
Block a user