ready for review
This commit is contained in:
@@ -192,7 +192,7 @@ class AlohaEnv(AbstractEnv):
|
||||
{
|
||||
"observation": TensorDict(obs, batch_size=[]),
|
||||
"reward": torch.tensor([reward], dtype=torch.float32),
|
||||
# succes and done are true when coverage > self.success_threshold in env
|
||||
# success and done are true when coverage > self.success_threshold in env
|
||||
"done": torch.tensor([done], dtype=torch.bool),
|
||||
"success": torch.tensor([success], dtype=torch.bool),
|
||||
},
|
||||
|
||||
@@ -62,27 +62,3 @@ def make_env(cfg, transform=None):
|
||||
{"seed": env_seed} for env_seed in range(cfg.seed, cfg.seed + cfg.rollout_batch_size)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
# def make_env(env_name, frame_skip, device, is_test=False):
|
||||
# env = GymEnv(
|
||||
# env_name,
|
||||
# frame_skip=frame_skip,
|
||||
# from_pixels=True,
|
||||
# pixels_only=False,
|
||||
# device=device,
|
||||
# )
|
||||
# env = TransformedEnv(env)
|
||||
# env.append_transform(NoopResetEnv(noops=30, random=True))
|
||||
# if not is_test:
|
||||
# env.append_transform(EndOfLifeTransform())
|
||||
# env.append_transform(RewardClipping(-1, 1))
|
||||
# env.append_transform(ToTensorImage())
|
||||
# env.append_transform(GrayScale())
|
||||
# env.append_transform(Resize(84, 84))
|
||||
# env.append_transform(CatFrames(N=4, dim=-3))
|
||||
# env.append_transform(RewardSum())
|
||||
# env.append_transform(StepCounter(max_steps=4500))
|
||||
# env.append_transform(DoubleToFloat())
|
||||
# env.append_transform(VecNorm(in_keys=["pixels"]))
|
||||
# return env
|
||||
|
||||
@@ -3,6 +3,8 @@ import logging
|
||||
from collections import deque
|
||||
from typing import Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from tensordict import TensorDict
|
||||
from torchrl.data.tensor_specs import (
|
||||
@@ -59,12 +61,30 @@ class PushtEnv(AbstractEnv):
|
||||
|
||||
self._env = PushTImageEnv(render_size=self.image_size)
|
||||
|
||||
def render(self, mode="rgb_array", width=384, height=384):
|
||||
def render(self, mode="rgb_array", width=96, height=96, with_marker=True):
|
||||
"""
|
||||
If with_marker is True, a cross-shaped cursor is drawn on the returned frame at the controller's latest target action.
|
||||
"""
|
||||
if width != height:
|
||||
raise NotImplementedError()
|
||||
tmp = self._env.render_size
|
||||
self._env.render_size = width
|
||||
out = self._env.render(mode)
|
||||
if width != self._env.render_size:
|
||||
self._env.render_cache = None
|
||||
self._env.render_size = width
|
||||
out = self._env.render(mode).copy()
|
||||
if with_marker and self._env.latest_action is not None:
|
||||
action = np.array(self._env.latest_action)
|
||||
coord = (action / 512 * self._env.render_size).astype(np.int32)
|
||||
marker_size = int(8 / 96 * self._env.render_size)
|
||||
thickness = int(1 / 96 * self._env.render_size)
|
||||
cv2.drawMarker(
|
||||
out,
|
||||
coord,
|
||||
color=(255, 0, 0),
|
||||
markerType=cv2.MARKER_CROSS,
|
||||
markerSize=marker_size,
|
||||
thickness=thickness,
|
||||
)
|
||||
self._env.render_size = tmp
|
||||
return out
|
||||
|
||||
|
||||
@@ -27,20 +27,6 @@ class PushTImageEnv(PushTEnv):
|
||||
img_obs = np.moveaxis(img, -1, 0)
|
||||
obs = {"image": img_obs, "agent_pos": agent_pos}
|
||||
|
||||
# draw action
|
||||
if self.latest_action is not None:
|
||||
action = np.array(self.latest_action)
|
||||
coord = (action / 512 * 96).astype(np.int32)
|
||||
marker_size = int(8 / 96 * self.render_size)
|
||||
thickness = int(1 / 96 * self.render_size)
|
||||
# cv2.drawMarker(
|
||||
# img,
|
||||
# coord,
|
||||
# color=(255, 0, 0),
|
||||
# markerType=cv2.MARKER_CROSS,
|
||||
# markerSize=marker_size,
|
||||
# thickness=thickness,
|
||||
# )
|
||||
self.render_cache = img
|
||||
|
||||
return obs
|
||||
|
||||
@@ -1,3 +1,44 @@
|
||||
"""Code from the original diffusion policy project.
|
||||
|
||||
Notes on how to load a checkpoint from the original repository:
|
||||
|
||||
In the original repository, run the eval and use a breakpoint to extract the policy weights.
|
||||
|
||||
```
|
||||
torch.save(policy.state_dict(), "weights.pt")
|
||||
```
|
||||
|
||||
In this repository, add a breakpoint somewhere after creating an equivalent policy and load in the weights:
|
||||
|
||||
```
|
||||
loaded = torch.load("weights.pt")
|
||||
aligned = {}
|
||||
their_prefix = "obs_encoder.obs_nets.image.backbone"
|
||||
our_prefix = "obs_encoder.key_model_map.image.backbone"
|
||||
aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
|
||||
their_prefix = "obs_encoder.obs_nets.image.pool"
|
||||
our_prefix = "obs_encoder.key_model_map.image.pool"
|
||||
aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
|
||||
their_prefix = "obs_encoder.obs_nets.image.nets.3"
|
||||
our_prefix = "obs_encoder.key_model_map.image.out"
|
||||
aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
|
||||
aligned.update({k: v for k, v in loaded.items() if k.startswith('model.')})
|
||||
# Note: this loads the weights into the EMA copy of the model (policy.ema_diffusion).
|
||||
missing_keys, unexpected_keys = policy.ema_diffusion.load_state_dict(aligned, strict=False)
|
||||
assert all('_dummy_variable' in k for k in missing_keys)
|
||||
assert len(unexpected_keys) == 0
|
||||
```
|
||||
|
||||
Then in that same runtime you can also save the weights with the new aligned state_dict:
|
||||
|
||||
```
|
||||
policy.save("weights.pt")
|
||||
```
|
||||
|
||||
Now you can remove the breakpoint and extra code and load in the weights just like with any other lerobot checkpoint.
|
||||
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import copy
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import timm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
from robomimic.models.base_nets import SpatialSoftmax
|
||||
from robomimic.models.base_nets import ResNet18Conv, SpatialSoftmax
|
||||
|
||||
from lerobot.common.policies.diffusion.model.crop_randomizer import CropRandomizer
|
||||
from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin
|
||||
@@ -15,17 +14,16 @@ from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules
|
||||
class RgbEncoder(nn.Module):
|
||||
"""Following `VisualCore` from Robomimic 0.2.0."""
|
||||
|
||||
def __init__(self, input_shape, model_name="resnet18", pretrained=False, relu=True, num_keypoints=32):
|
||||
def __init__(self, input_shape, relu=True, pretrained=False, num_keypoints=32):
|
||||
"""
|
||||
input_shape: channel-first input shape (C, H, W)
|
||||
resnet_name: a timm model name.
|
||||
pretrained: whether to load pretrained weights into the backbone.
|
||||
rele: whether to use relu as a final step.
|
||||
relu: whether to use relu as a final step.
|
||||
num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image).
|
||||
"""
|
||||
super().__init__()
|
||||
self.backbone = timm.create_model(model_name, pretrained, num_classes=0, global_pool="")
|
||||
# self.backbone = ResNet18Conv(input_channel=input_shape[0])
|
||||
self.backbone = ResNet18Conv(input_channel=input_shape[0], pretrained=pretrained)
|
||||
# Figure out the feature map shape.
|
||||
with torch.inference_mode():
|
||||
feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:])
|
||||
@@ -34,7 +32,6 @@ class RgbEncoder(nn.Module):
|
||||
self.relu = nn.ReLU() if relu else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
# TODO(now): make nonlinearity optional
|
||||
return self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1)))
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ import time
|
||||
import hydra
|
||||
import torch
|
||||
|
||||
from lerobot.common.ema import update_ema_parameters
|
||||
from lerobot.common.policies.abstract import AbstractPolicy
|
||||
from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
|
||||
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
|
||||
@@ -21,6 +20,7 @@ class DiffusionPolicy(AbstractPolicy):
|
||||
cfg_rgb_model,
|
||||
cfg_obs_encoder,
|
||||
cfg_optimizer,
|
||||
cfg_ema,
|
||||
shape_meta: dict,
|
||||
horizon,
|
||||
n_action_steps,
|
||||
@@ -71,8 +71,13 @@ class DiffusionPolicy(AbstractPolicy):
|
||||
self.diffusion.cuda()
|
||||
|
||||
self.ema_diffusion = None
|
||||
if self.cfg.ema.enable:
|
||||
self.ema = None
|
||||
if self.cfg.use_ema:
|
||||
self.ema_diffusion = copy.deepcopy(self.diffusion)
|
||||
self.ema = hydra.utils.instantiate(
|
||||
cfg_ema,
|
||||
model=self.ema_diffusion,
|
||||
)
|
||||
|
||||
self.optimizer = hydra.utils.instantiate(
|
||||
cfg_optimizer,
|
||||
@@ -175,8 +180,8 @@ class DiffusionPolicy(AbstractPolicy):
|
||||
self.optimizer.zero_grad()
|
||||
self.lr_scheduler.step()
|
||||
|
||||
if self.cfg.ema.enable:
|
||||
update_ema_parameters(self.ema_diffusion, self.diffusion, self.cfg.ema.rate)
|
||||
if self.ema is not None:
|
||||
self.ema.step(self.diffusion)
|
||||
|
||||
info = {
|
||||
"loss": loss.item(),
|
||||
|
||||
@@ -16,6 +16,7 @@ def make_policy(cfg):
|
||||
cfg_rgb_model=cfg.rgb_model,
|
||||
cfg_obs_encoder=cfg.obs_encoder,
|
||||
cfg_optimizer=cfg.optimizer,
|
||||
cfg_ema=cfg.ema,
|
||||
n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
|
||||
**cfg.policy,
|
||||
)
|
||||
@@ -39,23 +40,4 @@ def make_policy(cfg):
|
||||
raise NotImplementedError()
|
||||
policy.load(cfg.policy.pretrained_model_path)
|
||||
|
||||
# import torch
|
||||
# loaded = torch.load('/home/alexander/Downloads/dp.pth')
|
||||
# aligned = {}
|
||||
|
||||
# their_prefix = "obs_encoder.obs_nets.image.backbone"
|
||||
# our_prefix = "obs_encoder.key_model_map.image.backbone"
|
||||
# aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
|
||||
# their_prefix = "obs_encoder.obs_nets.image.pool"
|
||||
# our_prefix = "obs_encoder.key_model_map.image.pool"
|
||||
# aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
|
||||
# their_prefix = "obs_encoder.obs_nets.image.nets.3"
|
||||
# our_prefix = "obs_encoder.key_model_map.image.out"
|
||||
# aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)})
|
||||
|
||||
# aligned.update({k: v for k, v in loaded.items() if k.startswith('model.')})
|
||||
# missing_keys, unexpected_keys = policy.diffusion.load_state_dict(aligned, strict=False)
|
||||
# assert all('_dummy_variable' in k for k in missing_keys)
|
||||
# assert len(unexpected_keys) == 0
|
||||
|
||||
return policy
|
||||
|
||||
@@ -12,6 +12,7 @@ shape_meta:
|
||||
action:
|
||||
shape: [2]
|
||||
|
||||
seed: 100000
|
||||
horizon: 16
|
||||
n_obs_steps: 2
|
||||
n_action_steps: 8
|
||||
@@ -26,7 +27,7 @@ eval_freq: 5000
|
||||
save_freq: 5000
|
||||
log_freq: 250
|
||||
|
||||
offline_steps: 50000
|
||||
offline_steps: 200000
|
||||
online_steps: 0
|
||||
|
||||
offline_prioritized_sampler: true
|
||||
@@ -58,9 +59,7 @@ policy:
|
||||
balanced_sampling: false
|
||||
utd: 1
|
||||
offline_steps: ${offline_steps}
|
||||
ema:
|
||||
enable: true
|
||||
rate: 0.999
|
||||
use_ema: true
|
||||
lr_scheduler: cosine
|
||||
lr_warmup_steps: 500
|
||||
grad_clip_norm: 10
|
||||
@@ -86,11 +85,18 @@ obs_encoder:
|
||||
norm_mean_std: [0.5, 0.5] # for PushT the original impl normalizes to [-1, 1] (maybe not the case for robomimic envs)
|
||||
|
||||
rgb_model:
|
||||
model_name: resnet18
|
||||
pretrained: false
|
||||
num_keypoints: 32
|
||||
relu: true
|
||||
|
||||
ema:
|
||||
_target_: lerobot.common.policies.diffusion.model.ema_model.EMAModel
|
||||
update_after_step: 0
|
||||
inv_gamma: 1.0
|
||||
power: 0.75
|
||||
min_value: 0.0
|
||||
max_value: 0.9999
|
||||
|
||||
optimizer:
|
||||
_target_: torch.optim.AdamW
|
||||
lr: 1.0e-4
|
||||
|
||||
@@ -50,6 +50,7 @@ def eval_policy(
|
||||
|
||||
def maybe_render_frame(env: EnvBase, _):
|
||||
if save_video or (return_first_video and i == 0): # noqa: B023
|
||||
# TODO now: generalize kwarg or maybe just remove it
|
||||
ep_frames.append(env.render()) # noqa: B023
|
||||
|
||||
with torch.inference_mode():
|
||||
|
||||
Reference in New Issue
Block a user