Follow transformers single file naming conventions (#124)
This commit is contained in:
@@ -63,14 +63,14 @@ class DiffusionPolicy(nn.Module):
|
||||
# Queues are populated during rollout of the policy; they contain the n latest observations and actions.
|
||||
self._queues = None
|
||||
|
||||
self.diffusion = _DiffusionUnetImagePolicy(cfg)
|
||||
self.diffusion = DiffusionModel(cfg)
|
||||
|
||||
# TODO(alexander-soare): This should probably be managed outside of the policy class.
|
||||
self.ema_diffusion = None
|
||||
self.ema = None
|
||||
if self.cfg.use_ema:
|
||||
self.ema_diffusion = copy.deepcopy(self.diffusion)
|
||||
self.ema = _EMA(cfg, model=self.ema_diffusion)
|
||||
self.ema = DiffusionEMA(cfg, model=self.ema_diffusion)
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
@@ -152,13 +152,13 @@ class DiffusionPolicy(nn.Module):
|
||||
assert len(unexpected_keys) == 0
|
||||
|
||||
|
||||
class _DiffusionUnetImagePolicy(nn.Module):
|
||||
class DiffusionModel(nn.Module):
|
||||
def __init__(self, cfg: DiffusionConfig):
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
|
||||
self.rgb_encoder = _RgbEncoder(cfg)
|
||||
self.unet = _ConditionalUnet1D(
|
||||
self.rgb_encoder = DiffusionRgbEncoder(cfg)
|
||||
self.unet = DiffusionConditionalUnet1d(
|
||||
cfg,
|
||||
global_cond_dim=(cfg.output_shapes["action"][0] + self.rgb_encoder.feature_dim) * cfg.n_obs_steps,
|
||||
)
|
||||
@@ -300,7 +300,7 @@ class _DiffusionUnetImagePolicy(nn.Module):
|
||||
return loss.mean()
|
||||
|
||||
|
||||
class _RgbEncoder(nn.Module):
|
||||
class DiffusionRgbEncoder(nn.Module):
|
||||
"""Encode an RGB image into a 1D feature vector.
|
||||
|
||||
Includes the ability to normalize and crop the image first.
|
||||
@@ -403,7 +403,7 @@ def _replace_submodules(
|
||||
return root_module
|
||||
|
||||
|
||||
class _SinusoidalPosEmb(nn.Module):
|
||||
class DiffusionSinusoidalPosEmb(nn.Module):
|
||||
"""1D sinusoidal positional embeddings as in Attention is All You Need."""
|
||||
|
||||
def __init__(self, dim: int):
|
||||
@@ -420,7 +420,7 @@ class _SinusoidalPosEmb(nn.Module):
|
||||
return emb
|
||||
|
||||
|
||||
class _Conv1dBlock(nn.Module):
|
||||
class DiffusionConv1dBlock(nn.Module):
|
||||
"""Conv1d --> GroupNorm --> Mish"""
|
||||
|
||||
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
|
||||
@@ -436,7 +436,7 @@ class _Conv1dBlock(nn.Module):
|
||||
return self.block(x)
|
||||
|
||||
|
||||
class _ConditionalUnet1D(nn.Module):
|
||||
class DiffusionConditionalUnet1d(nn.Module):
|
||||
"""A 1D convolutional UNet with FiLM modulation for conditioning.
|
||||
|
||||
Note: this removes local conditioning as compared to the original diffusion policy code.
|
||||
@@ -449,7 +449,7 @@ class _ConditionalUnet1D(nn.Module):
|
||||
|
||||
# Encoder for the diffusion timestep.
|
||||
self.diffusion_step_encoder = nn.Sequential(
|
||||
_SinusoidalPosEmb(cfg.diffusion_step_embed_dim),
|
||||
DiffusionSinusoidalPosEmb(cfg.diffusion_step_embed_dim),
|
||||
nn.Linear(cfg.diffusion_step_embed_dim, cfg.diffusion_step_embed_dim * 4),
|
||||
nn.Mish(),
|
||||
nn.Linear(cfg.diffusion_step_embed_dim * 4, cfg.diffusion_step_embed_dim),
|
||||
@@ -477,8 +477,8 @@ class _ConditionalUnet1D(nn.Module):
|
||||
self.down_modules.append(
|
||||
nn.ModuleList(
|
||||
[
|
||||
_ConditionalResidualBlock1D(dim_in, dim_out, **common_res_block_kwargs),
|
||||
_ConditionalResidualBlock1D(dim_out, dim_out, **common_res_block_kwargs),
|
||||
DiffusionConditionalResidualBlock1d(dim_in, dim_out, **common_res_block_kwargs),
|
||||
DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
|
||||
# Downsample as long as it is not the last block.
|
||||
nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(),
|
||||
]
|
||||
@@ -488,8 +488,12 @@ class _ConditionalUnet1D(nn.Module):
|
||||
# Processing in the middle of the auto-encoder.
|
||||
self.mid_modules = nn.ModuleList(
|
||||
[
|
||||
_ConditionalResidualBlock1D(cfg.down_dims[-1], cfg.down_dims[-1], **common_res_block_kwargs),
|
||||
_ConditionalResidualBlock1D(cfg.down_dims[-1], cfg.down_dims[-1], **common_res_block_kwargs),
|
||||
DiffusionConditionalResidualBlock1d(
|
||||
cfg.down_dims[-1], cfg.down_dims[-1], **common_res_block_kwargs
|
||||
),
|
||||
DiffusionConditionalResidualBlock1d(
|
||||
cfg.down_dims[-1], cfg.down_dims[-1], **common_res_block_kwargs
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -501,8 +505,8 @@ class _ConditionalUnet1D(nn.Module):
|
||||
nn.ModuleList(
|
||||
[
|
||||
# dim_in * 2, because it takes the encoder's skip connection as well
|
||||
_ConditionalResidualBlock1D(dim_in * 2, dim_out, **common_res_block_kwargs),
|
||||
_ConditionalResidualBlock1D(dim_out, dim_out, **common_res_block_kwargs),
|
||||
DiffusionConditionalResidualBlock1d(dim_in * 2, dim_out, **common_res_block_kwargs),
|
||||
DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
|
||||
# Upsample as long as it is not the last block.
|
||||
nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
|
||||
]
|
||||
@@ -510,7 +514,7 @@ class _ConditionalUnet1D(nn.Module):
|
||||
)
|
||||
|
||||
self.final_conv = nn.Sequential(
|
||||
_Conv1dBlock(cfg.down_dims[0], cfg.down_dims[0], kernel_size=cfg.kernel_size),
|
||||
DiffusionConv1dBlock(cfg.down_dims[0], cfg.down_dims[0], kernel_size=cfg.kernel_size),
|
||||
nn.Conv1d(cfg.down_dims[0], cfg.output_shapes["action"][0], 1),
|
||||
)
|
||||
|
||||
@@ -559,7 +563,7 @@ class _ConditionalUnet1D(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class _ConditionalResidualBlock1D(nn.Module):
|
||||
class DiffusionConditionalResidualBlock1d(nn.Module):
|
||||
"""ResNet style 1D convolutional block with FiLM modulation for conditioning."""
|
||||
|
||||
def __init__(
|
||||
@@ -578,13 +582,13 @@ class _ConditionalResidualBlock1D(nn.Module):
|
||||
self.use_film_scale_modulation = use_film_scale_modulation
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.conv1 = _Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
|
||||
self.conv1 = DiffusionConv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
|
||||
|
||||
# FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale.
|
||||
cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels
|
||||
self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))
|
||||
|
||||
self.conv2 = _Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
|
||||
self.conv2 = DiffusionConv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
|
||||
|
||||
# A final convolution for dimension matching the residual (if needed).
|
||||
self.residual_conv = (
|
||||
@@ -617,7 +621,7 @@ class _ConditionalResidualBlock1D(nn.Module):
|
||||
return out
|
||||
|
||||
|
||||
class _EMA:
|
||||
class DiffusionEMA:
|
||||
"""
|
||||
Exponential Moving Average of model weights
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user