Use PytorchModelHubMixin to save models as safetensors (#125)

Co-authored-by: Remi <re.cadene@gmail.com>
Alexander Soare
2024-05-01 16:17:18 +01:00
committed by GitHub
parent 01d5490d44
commit a4891095e4
18 changed files with 556 additions and 527 deletions


@@ -9,7 +9,6 @@ TODO(alexander-soare):
"""
import copy
import logging
import math
from collections import deque
from typing import Callable
@@ -19,6 +18,7 @@ import torch
 import torch.nn.functional as F  # noqa: N812
 import torchvision
 from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
+from huggingface_hub import PyTorchModelHubMixin
 from robomimic.models.base_nets import SpatialSoftmax
 from torch import Tensor, nn
 from torch.nn.modules.batchnorm import _BatchNorm
@@ -32,7 +32,7 @@ from lerobot.common.policies.utils import (
 )


-class DiffusionPolicy(nn.Module):
+class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
     """
     Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
     (paper: https://arxiv.org/abs/2303.04137, code: https://github.com/real-stanford/diffusion_policy).
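Mixing in `PyTorchModelHubMixin` is what gives the policy `save_pretrained` / `from_pretrained` (and `push_to_hub`), with weights serialized as safetensors rather than a pickled state dict. A minimal sketch of the pattern, using a hypothetical toy module and local directory in place of the real policy:

```python
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

class TinyPolicy(nn.Module, PyTorchModelHubMixin):  # hypothetical stand-in for DiffusionPolicy
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 2)

policy = TinyPolicy()
policy.save_pretrained("tiny_policy")                 # writes model.safetensors to the directory
restored = TinyPolicy.from_pretrained("tiny_policy")  # rebuilds the module and reloads the weights
```

The same two calls also work against the Hub (a repo id instead of a local path), which is the point of adopting the mixin here.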
@@ -41,45 +41,50 @@ class DiffusionPolicy(nn.Module):
name = "diffusion"
def __init__(
self, cfg: DiffusionConfig | None = None, lr_scheduler_num_training_steps: int = 0, dataset_stats=None
self,
config: DiffusionConfig | None = None,
dataset_stats=None,
):
"""
Args:
cfg: Policy configuration class instance or None, in which case the default instantiation of the
configuration class is used.
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
"""
super().__init__()
# TODO(alexander-soare): LR scheduler will be removed.
assert lr_scheduler_num_training_steps > 0
if cfg is None:
cfg = DiffusionConfig()
self.cfg = cfg
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.input_normalization_modes, dataset_stats)
self.normalize_targets = Normalize(cfg.output_shapes, cfg.output_normalization_modes, dataset_stats)
if config is None:
config = DiffusionConfig()
self.config = config
self.normalize_inputs = Normalize(
config.input_shapes, config.input_normalization_modes, dataset_stats
)
self.normalize_targets = Normalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
cfg.output_shapes, cfg.output_normalization_modes, dataset_stats
config.output_shapes, config.output_normalization_modes, dataset_stats
)
# queues are populated during rollout of the policy, they contain the n latest observations and actions
self._queues = None
self.diffusion = DiffusionModel(cfg)
self.diffusion = DiffusionModel(config)
# TODO(alexander-soare): This should probably be managed outside of the policy class.
self.ema_diffusion = None
self.ema = None
if self.cfg.use_ema:
if self.config.use_ema:
self.ema_diffusion = copy.deepcopy(self.diffusion)
self.ema = DiffusionEMA(cfg, model=self.ema_diffusion)
self.ema = DiffusionEMA(config, model=self.ema_diffusion)
def reset(self):
"""
Clear observation and action queues. Should be called on `env.reset()`
"""
self._queues = {
"observation.image": deque(maxlen=self.cfg.n_obs_steps),
"observation.state": deque(maxlen=self.cfg.n_obs_steps),
"action": deque(maxlen=self.cfg.n_action_steps),
"observation.image": deque(maxlen=self.config.n_obs_steps),
"observation.state": deque(maxlen=self.config.n_obs_steps),
"action": deque(maxlen=self.config.n_action_steps),
}
@torch.no_grad
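The queue comment above is the whole temporal-context mechanism: a `deque` with `maxlen` silently evicts its oldest entry on append, so each queue always holds exactly the latest `n_obs_steps` observations (or `n_action_steps` actions). A quick illustration with a hypothetical window of 2:

```python
from collections import deque

obs_queue = deque(maxlen=2)  # stands in for self._queues["observation.image"] with n_obs_steps=2
for t in range(4):
    obs_queue.append(f"obs_{t}")
print(list(obs_queue))  # ['obs_2', 'obs_3'] -- only the two most recent observations remain
```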
@@ -138,46 +143,34 @@ class DiffusionPolicy(nn.Module):
         loss = self.diffusion.compute_loss(batch)
         return {"loss": loss}

-    def save(self, fp):
-        torch.save(self.state_dict(), fp)
-
-    def load(self, fp):
-        d = torch.load(fp)
-        missing_keys, unexpected_keys = self.load_state_dict(d, strict=False)
-        if len(missing_keys) > 0:
-            assert all(k.startswith("ema_diffusion.") for k in missing_keys)
-            logging.warning(
-                "DiffusionPolicy.load expected ema parameters in loaded state dict but none were found."
-            )
-        assert len(unexpected_keys) == 0
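Nothing replaces these two methods inside the class; checkpointing now goes through the inherited mixin API. A hedged before/after sketch (the directory path is illustrative, and `from_pretrained` re-instantiating this particular class depends on its init kwargs being serializable):

```python
# before (removed above):
#     policy.save("model.pt")
#     policy.load("model.pt")
# after, via PyTorchModelHubMixin:
policy.save_pretrained("outputs/diffusion_policy")
policy = DiffusionPolicy.from_pretrained("outputs/diffusion_policy")
```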
 class DiffusionModel(nn.Module):
-    def __init__(self, cfg: DiffusionConfig):
+    def __init__(self, config: DiffusionConfig):
         super().__init__()
-        self.cfg = cfg
+        self.config = config

-        self.rgb_encoder = DiffusionRgbEncoder(cfg)
+        self.rgb_encoder = DiffusionRgbEncoder(config)
         self.unet = DiffusionConditionalUnet1d(
-            cfg,
-            global_cond_dim=(cfg.output_shapes["action"][0] + self.rgb_encoder.feature_dim) * cfg.n_obs_steps,
+            config,
+            global_cond_dim=(config.output_shapes["action"][0] + self.rgb_encoder.feature_dim)
+            * config.n_obs_steps,
         )

         self.noise_scheduler = DDPMScheduler(
-            num_train_timesteps=cfg.num_train_timesteps,
-            beta_start=cfg.beta_start,
-            beta_end=cfg.beta_end,
-            beta_schedule=cfg.beta_schedule,
+            num_train_timesteps=config.num_train_timesteps,
+            beta_start=config.beta_start,
+            beta_end=config.beta_end,
+            beta_schedule=config.beta_schedule,
             variance_type="fixed_small",
-            clip_sample=cfg.clip_sample,
-            clip_sample_range=cfg.clip_sample_range,
-            prediction_type=cfg.prediction_type,
+            clip_sample=config.clip_sample,
+            clip_sample_range=config.clip_sample_range,
+            prediction_type=config.prediction_type,
         )

-        if cfg.num_inference_steps is None:
+        if config.num_inference_steps is None:
             self.num_inference_steps = self.noise_scheduler.config.num_train_timesteps
         else:
-            self.num_inference_steps = cfg.num_inference_steps
+            self.num_inference_steps = config.num_inference_steps

     # ========= inference ============
     def conditional_sample(
@@ -188,7 +181,7 @@ class DiffusionModel(nn.Module):
         # Sample prior.
         sample = torch.randn(
-            size=(batch_size, self.cfg.horizon, self.cfg.output_shapes["action"][0]),
+            size=(batch_size, self.config.horizon, self.config.output_shapes["action"][0]),
             dtype=dtype,
             device=device,
             generator=generator,
@@ -218,7 +211,7 @@ class DiffusionModel(nn.Module):
"""
assert set(batch).issuperset({"observation.state", "observation.image"})
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
assert n_obs_steps == self.cfg.n_obs_steps
assert n_obs_steps == self.config.n_obs_steps
# Extract image feature (first combine batch and sequence dims).
img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
@@ -231,10 +224,10 @@ class DiffusionModel(nn.Module):
         sample = self.conditional_sample(batch_size, global_cond=global_cond)

         # `horizon` steps worth of actions (from the first observation).
-        actions = sample[..., : self.cfg.output_shapes["action"][0]]
+        actions = sample[..., : self.config.output_shapes["action"][0]]
         # Extract `n_action_steps` steps worth of actions (from the current observation).
         start = n_obs_steps - 1
-        end = start + self.cfg.n_action_steps
+        end = start + self.config.n_action_steps
         actions = actions[:, start:end]

         return actions
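The start/end arithmetic is easier to see with concrete numbers. With hypothetical values n_obs_steps=2, horizon=16, n_action_steps=8, the slice keeps the 8 actions beginning at the step aligned with the current observation:

```python
n_obs_steps, horizon, n_action_steps = 2, 16, 8  # hypothetical config values
start = n_obs_steps - 1          # 1: the action index aligned with the latest observation
end = start + n_action_steps     # 9
print(list(range(horizon))[start:end])  # [1, 2, 3, 4, 5, 6, 7, 8]
```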
@@ -253,8 +246,8 @@ class DiffusionModel(nn.Module):
assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"})
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
horizon = batch["action"].shape[1]
assert horizon == self.cfg.horizon
assert n_obs_steps == self.cfg.n_obs_steps
assert horizon == self.config.horizon
assert n_obs_steps == self.config.n_obs_steps
# Extract image feature (first combine batch and sequence dims).
img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
@@ -283,12 +276,12 @@ class DiffusionModel(nn.Module):
         # Compute the loss.
         # The target is either the original trajectory, or the noise.
-        if self.cfg.prediction_type == "epsilon":
+        if self.config.prediction_type == "epsilon":
             target = eps
-        elif self.cfg.prediction_type == "sample":
+        elif self.config.prediction_type == "sample":
             target = batch["action"]
         else:
-            raise ValueError(f"Unsupported prediction type {self.cfg.prediction_type}")
+            raise ValueError(f"Unsupported prediction type {self.config.prediction_type}")

         loss = F.mse_loss(pred, target, reduction="none")
@@ -306,29 +299,29 @@ class DiffusionRgbEncoder(nn.Module):
     Includes the ability to normalize and crop the image first.
     """

-    def __init__(self, cfg: DiffusionConfig):
+    def __init__(self, config: DiffusionConfig):
         super().__init__()
         # Set up optional preprocessing.
-        if cfg.crop_shape is not None:
+        if config.crop_shape is not None:
             self.do_crop = True
             # Always use center crop for eval
-            self.center_crop = torchvision.transforms.CenterCrop(cfg.crop_shape)
-            if cfg.crop_is_random:
-                self.maybe_random_crop = torchvision.transforms.RandomCrop(cfg.crop_shape)
+            self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
+            if config.crop_is_random:
+                self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
             else:
                 self.maybe_random_crop = self.center_crop
         else:
             self.do_crop = False

         # Set up backbone.
-        backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
-            weights=cfg.pretrained_backbone_weights
+        backbone_model = getattr(torchvision.models, config.vision_backbone)(
+            weights=config.pretrained_backbone_weights
         )
         # Note: This assumes that the layer4 feature map is children()[-3]
         # TODO(alexander-soare): Use a safer alternative.
         self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
-        if cfg.use_group_norm:
-            if cfg.pretrained_backbone_weights:
+        if config.use_group_norm:
+            if config.pretrained_backbone_weights:
                 raise ValueError(
                     "You can't replace BatchNorm in a pretrained model without ruining the weights!"
                 )
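The guard above exists because swapping normalization layers discards the batch statistics the pretrained weights depend on. The repo's actual replacement helper is elided from this hunk; the following is only a minimal sketch of the BatchNorm-to-GroupNorm swap that `use_group_norm` implies, with a hypothetical `num_groups` (which must divide each layer's channel count):

```python
import torch.nn as nn

def replace_bn_with_gn(module: nn.Module, num_groups: int = 16) -> nn.Module:
    """Recursively swap every BatchNorm layer for a GroupNorm over the same channels."""
    for name, child in module.named_children():
        if isinstance(child, nn.modules.batchnorm._BatchNorm):
            setattr(module, name, nn.GroupNorm(num_groups, child.num_features))
        else:
            replace_bn_with_gn(child, num_groups)
    return module
```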
@@ -342,11 +335,11 @@ class DiffusionRgbEncoder(nn.Module):
         # Use a dry run to get the feature map shape.
         with torch.inference_mode():
             feat_map_shape = tuple(
-                self.backbone(torch.zeros(size=(1, *cfg.input_shapes["observation.image"]))).shape[1:]
+                self.backbone(torch.zeros(size=(1, *config.input_shapes["observation.image"]))).shape[1:]
             )
-        self.pool = SpatialSoftmax(feat_map_shape, num_kp=cfg.spatial_softmax_num_keypoints)
-        self.feature_dim = cfg.spatial_softmax_num_keypoints * 2
-        self.out = nn.Linear(cfg.spatial_softmax_num_keypoints * 2, self.feature_dim)
+        self.pool = SpatialSoftmax(feat_map_shape, num_kp=config.spatial_softmax_num_keypoints)
+        self.feature_dim = config.spatial_softmax_num_keypoints * 2
+        self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
         self.relu = nn.ReLU()

     def forward(self, x: Tensor) -> Tensor:
@@ -442,34 +435,34 @@ class DiffusionConditionalUnet1d(nn.Module):
     Note: this removes local conditioning as compared to the original diffusion policy code.
     """

-    def __init__(self, cfg: DiffusionConfig, global_cond_dim: int):
+    def __init__(self, config: DiffusionConfig, global_cond_dim: int):
         super().__init__()
-        self.cfg = cfg
+        self.config = config

         # Encoder for the diffusion timestep.
         self.diffusion_step_encoder = nn.Sequential(
-            DiffusionSinusoidalPosEmb(cfg.diffusion_step_embed_dim),
-            nn.Linear(cfg.diffusion_step_embed_dim, cfg.diffusion_step_embed_dim * 4),
+            DiffusionSinusoidalPosEmb(config.diffusion_step_embed_dim),
+            nn.Linear(config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4),
             nn.Mish(),
-            nn.Linear(cfg.diffusion_step_embed_dim * 4, cfg.diffusion_step_embed_dim),
+            nn.Linear(config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim),
         )

         # The FiLM conditioning dimension.
-        cond_dim = cfg.diffusion_step_embed_dim + global_cond_dim
+        cond_dim = config.diffusion_step_embed_dim + global_cond_dim

         # In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we
         # just reverse these.
-        in_out = [(cfg.output_shapes["action"][0], cfg.down_dims[0])] + list(
-            zip(cfg.down_dims[:-1], cfg.down_dims[1:], strict=True)
+        in_out = [(config.output_shapes["action"][0], config.down_dims[0])] + list(
+            zip(config.down_dims[:-1], config.down_dims[1:], strict=True)
         )
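The in/out comment above is easiest to check with concrete numbers. A sketch with hypothetical values (action dimension 2, down_dims (512, 1024, 2048)):

```python
action_dim, down_dims = 2, (512, 1024, 2048)  # hypothetical config values
in_out = [(action_dim, down_dims[0])] + list(zip(down_dims[:-1], down_dims[1:], strict=True))
print(in_out)  # [(2, 512), (512, 1024), (1024, 2048)] -- the decoder walks these pairs in reverse
```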
         # Unet encoder.
         common_res_block_kwargs = {
             "cond_dim": cond_dim,
-            "kernel_size": cfg.kernel_size,
-            "n_groups": cfg.n_groups,
-            "use_film_scale_modulation": cfg.use_film_scale_modulation,
+            "kernel_size": config.kernel_size,
+            "n_groups": config.n_groups,
+            "use_film_scale_modulation": config.use_film_scale_modulation,
         }
         self.down_modules = nn.ModuleList([])
         for ind, (dim_in, dim_out) in enumerate(in_out):
@@ -489,10 +482,10 @@ class DiffusionConditionalUnet1d(nn.Module):
         self.mid_modules = nn.ModuleList(
             [
                 DiffusionConditionalResidualBlock1d(
-                    cfg.down_dims[-1], cfg.down_dims[-1], **common_res_block_kwargs
+                    config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
                 ),
                 DiffusionConditionalResidualBlock1d(
-                    cfg.down_dims[-1], cfg.down_dims[-1], **common_res_block_kwargs
+                    config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
                 ),
             ]
         )
@@ -514,8 +507,8 @@ class DiffusionConditionalUnet1d(nn.Module):
         )

         self.final_conv = nn.Sequential(
-            DiffusionConv1dBlock(cfg.down_dims[0], cfg.down_dims[0], kernel_size=cfg.kernel_size),
-            nn.Conv1d(cfg.down_dims[0], cfg.output_shapes["action"][0], 1),
+            DiffusionConv1dBlock(config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size),
+            nn.Conv1d(config.down_dims[0], config.output_shapes["action"][0], 1),
         )

     def forward(self, x: Tensor, timestep: Tensor | int, global_cond=None) -> Tensor:
@@ -626,13 +619,13 @@ class DiffusionEMA:
     Exponential Moving Average of models weights
     """

-    def __init__(self, cfg: DiffusionConfig, model: nn.Module):
+    def __init__(self, config: DiffusionConfig, model: nn.Module):
         """
         @crowsonkb's notes on EMA Warmup:
-            If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
-            to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
-            gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
-            at 215.4k steps).
+            If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models
+            you plan to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999
+            at 1M steps), gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999
+            at 10K steps, 0.9999 at 215.4k steps).
         Args:
             inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
             power (float): Exponential factor of EMA warmup. Default: 2/3.
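The decay figures quoted in the docstring follow from the warmup schedule in @crowsonkb's note, decay(step) = 1 - (1 + step / inv_gamma) ** (-power), clamped to [min_alpha, max_alpha] per the fields set just below. A quick check of the quoted numbers:

```python
def ema_decay(step: int, inv_gamma: float = 1.0, power: float = 2 / 3) -> float:
    # Warmup schedule only; the min_alpha/max_alpha clamping is omitted here.
    return 1 - (1 + step / inv_gamma) ** (-power)

print(f"{ema_decay(31_600):.4f}")               # ~0.9990 (gamma=1, power=2/3)
print(f"{ema_decay(1_000_000):.5f}")            # ~0.99990
print(f"{ema_decay(10_000, power=3 / 4):.4f}")  # ~0.9990 (power=3/4)
```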
@@ -643,11 +636,11 @@ class DiffusionEMA:
         self.averaged_model.eval()
         self.averaged_model.requires_grad_(False)

-        self.update_after_step = cfg.ema_update_after_step
-        self.inv_gamma = cfg.ema_inv_gamma
-        self.power = cfg.ema_power
-        self.min_alpha = cfg.ema_min_alpha
-        self.max_alpha = cfg.ema_max_alpha
+        self.update_after_step = config.ema_update_after_step
+        self.inv_gamma = config.ema_inv_gamma
+        self.power = config.ema_power
+        self.min_alpha = config.ema_min_alpha
+        self.max_alpha = config.ema_max_alpha
         self.alpha = 0.0
         self.optimization_step = 0