Merge branch 'user/alexander-soare/train_pusht' into user/alexander-soare/multistep_policy_and_serial_env

This commit is contained in:
Alexander Soare
2024-03-14 16:06:21 +00:00
5 changed files with 241 additions and 19 deletions

View File

@@ -1,15 +1,37 @@
import copy
from typing import Dict, Tuple, Union
from typing import Dict, Optional, Tuple, Union
import timm
import torch
import torch.nn as nn
import torchvision
from robomimic.models.base_nets import SpatialSoftmax
from lerobot.common.policies.diffusion.model.crop_randomizer import CropRandomizer
from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin
from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules
class RgbEncoder(nn.Module):
"""Following `VisualCore` from Robomimic 0.2.0."""
def __init__(self, input_shape, model_name="resnet18", pretrained=False, num_keypoints=32):
"""
resnet_name: a timm model name.
pretrained: whether to use timm pretrained weights.
num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image).
"""
super().__init__()
self.backbone = timm.create_model(model_name, pretrained, num_classes=0, global_pool="")
# Figure out the feature map shape.
with torch.inference_mode():
feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:])
self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints)
def forward(self, x):
return torch.flatten(self.pool(self.backbone(x)), start_dim=1)
class MultiImageObsEncoder(ModuleAttrMixin):
def __init__(
self,
@@ -24,7 +46,7 @@ class MultiImageObsEncoder(ModuleAttrMixin):
share_rgb_model: bool = False,
# renormalize rgb input with imagenet normalization
# assuming input in [0,1]
imagenet_norm: bool = False,
norm_mean_std: Optional[tuple[float, float]] = None,
):
"""
Assumes rgb input: B,C,H,W
@@ -98,10 +120,9 @@ class MultiImageObsEncoder(ModuleAttrMixin):
this_normalizer = torchvision.transforms.CenterCrop(size=(h, w))
# configure normalizer
this_normalizer = nn.Identity()
if imagenet_norm:
# TODO(rcadene): move normalizer to dataset and env
if norm_mean_std is not None:
this_normalizer = torchvision.transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
mean=norm_mean_std[0], std=norm_mean_std[1]
)
this_transform = nn.Sequential(this_resizer, this_randomizer, this_normalizer)

View File

@@ -7,7 +7,7 @@ import torch
from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder, RgbEncoder
class DiffusionPolicy(AbstractPolicy):
@@ -38,7 +38,7 @@ class DiffusionPolicy(AbstractPolicy):
self.cfg = cfg
noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
rgb_model = hydra.utils.instantiate(cfg_rgb_model)
rgb_model = RgbEncoder(input_shape=shape_meta.obs.image.shape, **cfg_rgb_model)
obs_encoder = MultiImageObsEncoder(
rgb_model=rgb_model,
**cfg_obs_encoder,

View File

@@ -42,8 +42,8 @@ policy:
num_inference_steps: 100
obs_as_global_cond: ${obs_as_global_cond}
# crop_shape: null
diffusion_step_embed_dim: 256 # before 128
down_dims: [256, 512, 1024] # before [512, 1024, 2048]
diffusion_step_embed_dim: 128
down_dims: [512, 1024, 2048]
kernel_size: 5
n_groups: 8
cond_predict_scale: True
@@ -81,12 +81,12 @@ obs_encoder:
# random_crop: True
use_group_norm: True
share_rgb_model: False
imagenet_norm: True
norm_mean_std: [0.5, 0.5] # for PushT the original impl normalizes to [-1, 1] (maybe not the case for robomimic envs)
rgb_model:
_target_: lerobot.common.policies.diffusion.pytorch_utils.get_resnet
name: resnet18
weights: null
model_name: resnet18
pretrained: false
num_keypoints: 32
ema:
_target_: lerobot.common.policies.diffusion.model.ema_model.EMAModel
@@ -109,13 +109,13 @@ training:
debug: False
resume: True
# optimization
# lr_scheduler: cosine
# lr_warmup_steps: 500
num_epochs: 8000
lr_scheduler: cosine
lr_warmup_steps: 500
num_epochs: 500
# gradient_accumulate_every: 1
# EMA destroys performance when used with BatchNorm
# replace BatchNorm with GroupNorm.
# use_ema: True
use_ema: True
freeze_encoder: False
# training loop control
# in epochs