backup wip

Alexander Soare
2024-04-11 17:51:35 +01:00
parent 91ff69d64c
commit 976a197f98
26 changed files with 661 additions and 2733 deletions

@@ -1,286 +1,307 @@
import logging
import math

import einops
import torch
import torch.nn as nn
from torch import Tensor

logger = logging.getLogger(__name__)

class _SinusoidalPosEmb(nn.Module):
# TODO(now): consolidate?
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
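
# A quick shape sketch (illustrative comment only, not part of the module): with dim=8,
# embedding a batch of 4 integer timesteps yields a (4, 8) tensor (sines first, then cosines):
#
#     _SinusoidalPosEmb(8)(torch.arange(4)).shape  # -> torch.Size([4, 8])
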
class _Conv1dBlock(nn.Module):
"""Conv1d --> GroupNorm --> Mish"""
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
super().__init__()
self.block = nn.Sequential(
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
nn.GroupNorm(n_groups, out_channels),
nn.Mish(),
)
def forward(self, x):
return self.block(x)
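
# Illustrative shape check (assumes an odd kernel_size, so `padding=kernel_size // 2`
# preserves the temporal length, and out_channels divisible by n_groups):
#
#     _Conv1dBlock(2, 16, kernel_size=3)(torch.zeros(1, 2, 10)).shape  # -> torch.Size([1, 16, 10])
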
class _ConditionalResidualBlock1D(nn.Module):
"""ResNet style 1D convolutional block with FiLM modulation for conditioning."""
def __init__(
self,
in_channels: int,
out_channels: int,
cond_dim: int,
kernel_size: int = 3,
n_groups: int = 8,
# Set to True to do scale modulation with FiLM as well as bias modulation (defaults to False meaning
# FiLM just modulates bias).
film_scale_modulation: bool = False,
):
super().__init__()
self.film_scale_modulation = film_scale_modulation
self.out_channels = out_channels
self.conv1 = _Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
# FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale.
cond_channels = out_channels * 2 if film_scale_modulation else out_channels
self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))
self.conv2 = _Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
        # A final convolution for matching the residual dimension to the block's output (if needed).
self.residual_conv = (
nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
)
def forward(self, x: Tensor, cond: Tensor) -> Tensor:
"""
Args:
x: (B, in_channels, T)
cond: (B, cond_dim)
Returns:
(B, out_channels, T)
"""
out = self.conv1(x)
# Get condition embedding. Unsqueeze for broadcasting to `out`, resulting in (B, out_channels, 1).
cond_embed = self.cond_encoder(cond).unsqueeze(-1)
if self.film_scale_modulation:
# Treat the embedding as a list of scales and biases.
scale = cond_embed[:, : self.out_channels]
bias = cond_embed[:, self.out_channels :]
out = scale * out + bias
else:
# Treat the embedding as biases.
out = out + cond_embed
out = self.conv2(out)
out = out + self.residual_conv(x)
return out
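
# Illustrative FiLM check (hypothetical sizes, comment only): conditioning a (1, 8, 10)
# feature map on a 4-dim vector with scale *and* bias modulation:
#
#     block = _ConditionalResidualBlock1D(8, 8, cond_dim=4, film_scale_modulation=True)
#     block(torch.zeros(1, 8, 10), torch.zeros(1, 4)).shape  # -> torch.Size([1, 8, 10])
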
class ConditionalUnet1D(nn.Module):
"""A 1D convolutional UNet with FiLM modulation for conditioning.
Two types of conditioning can be applied:
- Global: Conditioning information that is aggregated over the whole observation window. This is
incorporated via the FiLM technique in the residual convolution blocks of the Unet's encoder/decoder.
- Local: Conditioning information for each timestep in the observation window. This is incorporated
by encoding the information via 1D convolutions and adding the resulting embeddings to the inputs and
outputs of the Unet's encoder/decoder.
"""
def __init__(
self,
input_dim: int,
local_cond_dim: int | None = None,
global_cond_dim: int | None = None,
diffusion_step_embed_dim: int = 256,
        down_dims: list[int] | None = None,
kernel_size: int = 3,
n_groups: int = 8,
film_scale_modulation: bool = False,
):
super().__init__()
if down_dims is None:
down_dims = [256, 512, 1024]
# Encoder for the diffusion timestep.
self.diffusion_step_encoder = nn.Sequential(
_SinusoidalPosEmb(diffusion_step_embed_dim),
nn.Linear(diffusion_step_embed_dim, diffusion_step_embed_dim * 4),
nn.Mish(),
nn.Linear(diffusion_step_embed_dim * 4, diffusion_step_embed_dim),
)
# The FiLM conditioning dimension.
cond_dim = diffusion_step_embed_dim
if global_cond_dim is not None:
cond_dim += global_cond_dim
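        # For example (illustrative): with the default diffusion_step_embed_dim=256 and
        # global_cond_dim=64, each FiLM block is conditioned on a 256 + 64 = 320-dim vector
        # (timestep embedding concatenated with the global conditioning feature).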
self.local_cond_down_encoder = None
self.local_cond_up_encoder = None
if local_cond_dim is not None:
# Encoder for the local conditioning. The output gets added to the Unet encoder input.
self.local_cond_down_encoder = _ConditionalResidualBlock1D(
local_cond_dim,
down_dims[0],
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
film_scale_modulation=film_scale_modulation,
)
# Encoder for the local conditioning. The output gets added to the Unet encoder output.
self.local_cond_up_encoder = _ConditionalResidualBlock1D(
local_cond_dim,
down_dims[0],
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
film_scale_modulation=film_scale_modulation,
)
# In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we
# just reverse these.
in_out = [(input_dim, down_dims[0])] + list(zip(down_dims[:-1], down_dims[1:], strict=True))
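        # For example (illustrative): with input_dim=2 and the default down_dims=[256, 512, 1024],
        # in_out == [(2, 256), (256, 512), (512, 1024)].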
# Unet encoder.
self.down_modules = nn.ModuleList([])
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (len(in_out) - 1)
self.down_modules.append(
nn.ModuleList(
[
_ConditionalResidualBlock1D(
dim_in,
dim_out,
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
film_scale_modulation=film_scale_modulation,
),
_ConditionalResidualBlock1D(
dim_out,
dim_out,
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
film_scale_modulation=film_scale_modulation,
),
# Downsample as long as it is not the last block.
nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(),
]
)
)
# Processing in the middle of the auto-encoder.
self.mid_modules = nn.ModuleList(
[
_ConditionalResidualBlock1D(
down_dims[-1],
down_dims[-1],
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
film_scale_modulation=film_scale_modulation,
),
_ConditionalResidualBlock1D(
down_dims[-1],
down_dims[-1],
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
film_scale_modulation=film_scale_modulation,
),
]
)
# Unet decoder.
self.up_modules = nn.ModuleList([])
for ind, (dim_out, dim_in) in enumerate(reversed(in_out[1:])):
is_last = ind >= (len(in_out) - 1)
self.up_modules.append(
nn.ModuleList(
[
_ConditionalResidualBlock1D(
dim_in * 2, # x2 as it takes the encoder's skip connection as well
dim_out,
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
film_scale_modulation=film_scale_modulation,
),
_ConditionalResidualBlock1D(
dim_out,
dim_out,
cond_dim=cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
film_scale_modulation=film_scale_modulation,
),
# Upsample as long as it is not the last block.
nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
]
)
)
self.final_conv = nn.Sequential(
_Conv1dBlock(down_dims[0], down_dims[0], kernel_size=kernel_size),
nn.Conv1d(down_dims[0], input_dim, 1),
)
logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
def forward(self, x: Tensor, timestep: Tensor | int, local_cond=None, global_cond=None) -> Tensor:
"""
Args:
x: (B, T, input_dim) tensor for input to the Unet.
timestep: (B,) tensor of (timestep_we_are_denoising_from - 1).
local_cond: (B, T, local_cond_dim)
global_cond: (B, global_cond_dim)
Returns:
(B, T, input_dim)
"""
# For 1D convolutions we'll need feature dimension first.
x = einops.rearrange(x, "b t d -> b d t")
if local_cond is not None:
if self.local_cond_down_encoder is None or self.local_cond_up_encoder is None:
raise ValueError(
"`local_cond` was provided but the relevant encoders weren't built at initialization."
)
local_cond = einops.rearrange(local_cond, "b t d -> b d t")
timesteps_embed = self.diffusion_step_encoder(timestep)
# If there is a global conditioning feature, concatenate it to the timestep embedding.
if global_cond is not None:
global_feature = torch.cat([timesteps_embed, global_cond], axis=-1)
else:
global_feature = timesteps_embed
encoder_skip_features: list[Tensor] = []
for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
x = resnet(x, global_feature)
if idx == 0 and local_cond is not None:
x = x + self.local_cond_down_encoder(local_cond, global_feature)
x = resnet2(x, global_feature)
encoder_skip_features.append(x)
x = downsample(x)
for mid_module in self.mid_modules:
x = mid_module(x, global_feature)
for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
x = torch.cat((x, encoder_skip_features.pop()), dim=1)
x = resnet(x, global_feature)
# Note: The condition in the original implementation is:
# if idx == len(self.up_modules) and local_cond is not None:
# But as they mention in their comments, this is incorrect. We use the correct condition here.
if idx == (len(self.up_modules) - 1) and local_cond is not None:
x = x + self.local_cond_up_encoder(local_cond, global_feature)
x = resnet2(x, global_feature)
x = upsample(x)
x = self.final_conv(x)
x = einops.rearrange(x, "b t h -> b h t")
x = einops.rearrange(x, "b d t -> b t d")
return x
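
# A minimal smoke test: a sketch under assumed shapes, not part of the library's API.
# It exercises global conditioning only.
if __name__ == "__main__":
    unet = ConditionalUnet1D(input_dim=2, global_cond_dim=64)
    x = torch.randn(8, 16, 2)  # (B, T, input_dim), with T divisible by 4
    timestep = torch.randint(0, 100, (8,))  # one diffusion step index per batch element
    out = unet(x, timestep, global_cond=torch.randn(8, 64))
    assert out.shape == (8, 16, 2)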