Follow transformers single file naming conventions (#124)

Author: Alexander Soare
Date: 2024-05-01 13:09:42 +01:00
Committed by: GitHub
Parent: 986583dc5c
Commit: 01d5490d44
6 changed files with 62 additions and 58 deletions


@@ -63,14 +63,14 @@ class DiffusionPolicy(nn.Module):
# queues are populated during rollout of the policy, they contain the n latest observations and actions
self._queues = None
self.diffusion = _DiffusionUnetImagePolicy(cfg)
self.diffusion = DiffusionModel(cfg)
# TODO(alexander-soare): This should probably be managed outside of the policy class.
self.ema_diffusion = None
self.ema = None
if self.cfg.use_ema:
self.ema_diffusion = copy.deepcopy(self.diffusion)
-self.ema = _EMA(cfg, model=self.ema_diffusion)
+self.ema = DiffusionEMA(cfg, model=self.ema_diffusion)
def reset(self):
"""
@@ -152,13 +152,13 @@ class DiffusionPolicy(nn.Module):
assert len(unexpected_keys) == 0
-class _DiffusionUnetImagePolicy(nn.Module):
+class DiffusionModel(nn.Module):
def __init__(self, cfg: DiffusionConfig):
super().__init__()
self.cfg = cfg
-self.rgb_encoder = _RgbEncoder(cfg)
-self.unet = _ConditionalUnet1D(
+self.rgb_encoder = DiffusionRgbEncoder(cfg)
+self.unet = DiffusionConditionalUnet1d(
cfg,
global_cond_dim=(cfg.output_shapes["action"][0] + self.rgb_encoder.feature_dim) * cfg.n_obs_steps,
)
@@ -300,7 +300,7 @@ class _DiffusionUnetImagePolicy(nn.Module):
return loss.mean()
-class _RgbEncoder(nn.Module):
+class DiffusionRgbEncoder(nn.Module):
"""Encoder an RGB image into a 1D feature vector.
Includes the ability to normalize and crop the image first.
@@ -403,7 +403,7 @@ def _replace_submodules(
return root_module
-class _SinusoidalPosEmb(nn.Module):
+class DiffusionSinusoidalPosEmb(nn.Module):
"""1D sinusoidal positional embeddings as in Attention is All You Need."""
def __init__(self, dim: int):
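
For reference, the embedding described by the docstring above can be sketched as a standalone function. The name sinusoidal_pos_emb and the exact tensor layout here are illustrative assumptions, not code from the diff:

    import math

    import torch

    def sinusoidal_pos_emb(x: torch.Tensor, dim: int) -> torch.Tensor:
        """Map a batch of scalar diffusion timesteps (B,) to (B, dim) sin/cos embeddings."""
        half_dim = dim // 2
        # Geometrically spaced frequencies, from 1 down to roughly 1/10000.
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -(math.log(10000) / (half_dim - 1)))
        angles = x.unsqueeze(-1) * freqs.unsqueeze(0)  # (B, half_dim)
        return torch.cat([angles.sin(), angles.cos()], dim=-1)  # (B, dim)

Pairing sines and cosines at geometrically spaced frequencies gives the UNet a smooth, resolution-independent encoding of the denoising step.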
@@ -420,7 +420,7 @@ class _SinusoidalPosEmb(nn.Module):
return emb
-class _Conv1dBlock(nn.Module):
+class DiffusionConv1dBlock(nn.Module):
"""Conv1d --> GroupNorm --> Mish"""
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
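
The "Conv1d --> GroupNorm --> Mish" docstring describes a block along these lines; this is a minimal sketch with an illustrative class name, not the file's exact implementation:

    import torch
    from torch import nn

    class Conv1dGroupNormMish(nn.Module):
        """Illustrative Conv1d -> GroupNorm -> Mish block operating on (B, C, T) tensors."""

        def __init__(self, inp_channels: int, out_channels: int, kernel_size: int, n_groups: int = 8):
            super().__init__()
            self.block = nn.Sequential(
                # "Same" padding so the temporal length T is preserved (kernel_size assumed odd).
                nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
                nn.GroupNorm(n_groups, out_channels),
                nn.Mish(),
            )

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.block(x)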
@@ -436,7 +436,7 @@ class _Conv1dBlock(nn.Module):
return self.block(x)
-class _ConditionalUnet1D(nn.Module):
+class DiffusionConditionalUnet1d(nn.Module):
"""A 1D convolutional UNet with FiLM modulation for conditioning.
Note: this removes local conditioning as compared to the original diffusion policy code.
@@ -449,7 +449,7 @@ class _ConditionalUnet1D(nn.Module):
# Encoder for the diffusion timestep.
self.diffusion_step_encoder = nn.Sequential(
-_SinusoidalPosEmb(cfg.diffusion_step_embed_dim),
+DiffusionSinusoidalPosEmb(cfg.diffusion_step_embed_dim),
nn.Linear(cfg.diffusion_step_embed_dim, cfg.diffusion_step_embed_dim * 4),
nn.Mish(),
nn.Linear(cfg.diffusion_step_embed_dim * 4, cfg.diffusion_step_embed_dim),
@@ -477,8 +477,8 @@ class _ConditionalUnet1D(nn.Module):
self.down_modules.append(
nn.ModuleList(
[
-_ConditionalResidualBlock1D(dim_in, dim_out, **common_res_block_kwargs),
-_ConditionalResidualBlock1D(dim_out, dim_out, **common_res_block_kwargs),
+DiffusionConditionalResidualBlock1d(dim_in, dim_out, **common_res_block_kwargs),
+DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
# Downsample as long as it is not the last block.
nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(),
]
@@ -488,8 +488,12 @@ class _ConditionalUnet1D(nn.Module):
# Processing in the middle of the auto-encoder.
self.mid_modules = nn.ModuleList(
[
-_ConditionalResidualBlock1D(cfg.down_dims[-1], cfg.down_dims[-1], **common_res_block_kwargs),
-_ConditionalResidualBlock1D(cfg.down_dims[-1], cfg.down_dims[-1], **common_res_block_kwargs),
+DiffusionConditionalResidualBlock1d(
+    cfg.down_dims[-1], cfg.down_dims[-1], **common_res_block_kwargs
+),
+DiffusionConditionalResidualBlock1d(
+    cfg.down_dims[-1], cfg.down_dims[-1], **common_res_block_kwargs
+),
]
)
@@ -501,8 +505,8 @@ class _ConditionalUnet1D(nn.Module):
nn.ModuleList(
[
# dim_in * 2, because it takes the encoder's skip connection as well
-_ConditionalResidualBlock1D(dim_in * 2, dim_out, **common_res_block_kwargs),
-_ConditionalResidualBlock1D(dim_out, dim_out, **common_res_block_kwargs),
+DiffusionConditionalResidualBlock1d(dim_in * 2, dim_out, **common_res_block_kwargs),
+DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
# Upsample as long as it is not the last block.
nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
]
@@ -510,7 +514,7 @@ class _ConditionalUnet1D(nn.Module):
)
self.final_conv = nn.Sequential(
-_Conv1dBlock(cfg.down_dims[0], cfg.down_dims[0], kernel_size=cfg.kernel_size),
+DiffusionConv1dBlock(cfg.down_dims[0], cfg.down_dims[0], kernel_size=cfg.kernel_size),
nn.Conv1d(cfg.down_dims[0], cfg.output_shapes["action"][0], 1),
)
@@ -559,7 +563,7 @@ class _ConditionalUnet1D(nn.Module):
return x
-class _ConditionalResidualBlock1D(nn.Module):
+class DiffusionConditionalResidualBlock1d(nn.Module):
"""ResNet style 1D convolutional block with FiLM modulation for conditioning."""
def __init__(
@@ -578,13 +582,13 @@ class _ConditionalResidualBlock1D(nn.Module):
self.use_film_scale_modulation = use_film_scale_modulation
self.out_channels = out_channels
-self.conv1 = _Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
+self.conv1 = DiffusionConv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
# FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale.
cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels
self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))
-self.conv2 = _Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
+self.conv2 = DiffusionConv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
# A final convolution for dimension matching the residual (if needed).
self.residual_conv = (
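
The FiLM comment above amounts to this: the conditioning vector is projected to a per-channel bias, and optionally also a per-channel scale, which modulate the intermediate feature map. A minimal sketch, with an illustrative function name and arguments that are not part of the diff:

    import torch
    from torch import nn

    def film_modulate(x: torch.Tensor, cond: torch.Tensor, cond_encoder: nn.Module, use_scale: bool) -> torch.Tensor:
        """Apply FiLM conditioning (https://arxiv.org/abs/1709.07871) to a (B, C, T) feature map.

        `cond_encoder` maps the (B, cond_dim) conditioning vector to 2*C channels when
        `use_scale` is True (scale and bias), otherwise to C channels (bias only).
        """
        cond_embed = cond_encoder(cond).unsqueeze(-1)  # (B, 2*C, 1) or (B, C, 1)
        if use_scale:
            scale, bias = cond_embed.chunk(2, dim=1)
            return scale * x + bias
        return x + cond_embed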
@@ -617,7 +621,7 @@ class _ConditionalResidualBlock1D(nn.Module):
return out
-class _EMA:
+class DiffusionEMA:
"""
Exponential Moving Average of model weights
"""