Follow transformers single file naming conventions (#124)
This commit is contained in:
@@ -2,7 +2,7 @@ from dataclasses import dataclass, field
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ActionChunkingTransformerConfig:
|
class ACTConfig:
|
||||||
"""Configuration class for the Action Chunking Transformers policy.
|
"""Configuration class for the Action Chunking Transformers policy.
|
||||||
|
|
||||||
Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
|
Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
|
||||||
|
|||||||
@@ -18,11 +18,11 @@ from torch import Tensor, nn
|
|||||||
from torchvision.models._utils import IntermediateLayerGetter
|
from torchvision.models._utils import IntermediateLayerGetter
|
||||||
from torchvision.ops.misc import FrozenBatchNorm2d
|
from torchvision.ops.misc import FrozenBatchNorm2d
|
||||||
|
|
||||||
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
|
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||||
|
|
||||||
|
|
||||||
class ActionChunkingTransformerPolicy(nn.Module):
|
class ACTPolicy(nn.Module):
|
||||||
"""
|
"""
|
||||||
Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
|
Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
|
||||||
Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
|
Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
|
||||||
@@ -30,7 +30,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||||||
|
|
||||||
name = "act"
|
name = "act"
|
||||||
|
|
||||||
def __init__(self, cfg: ActionChunkingTransformerConfig | None = None, dataset_stats=None):
|
def __init__(self, cfg: ACTConfig | None = None, dataset_stats=None):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
cfg: Policy configuration class instance or None, in which case the default instantiation of the
|
cfg: Policy configuration class instance or None, in which case the default instantiation of the
|
||||||
@@ -38,14 +38,14 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if cfg is None:
|
if cfg is None:
|
||||||
cfg = ActionChunkingTransformerConfig()
|
cfg = ACTConfig()
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.input_normalization_modes, dataset_stats)
|
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.input_normalization_modes, dataset_stats)
|
||||||
self.normalize_targets = Normalize(cfg.output_shapes, cfg.output_normalization_modes, dataset_stats)
|
self.normalize_targets = Normalize(cfg.output_shapes, cfg.output_normalization_modes, dataset_stats)
|
||||||
self.unnormalize_outputs = Unnormalize(
|
self.unnormalize_outputs = Unnormalize(
|
||||||
cfg.output_shapes, cfg.output_normalization_modes, dataset_stats
|
cfg.output_shapes, cfg.output_normalization_modes, dataset_stats
|
||||||
)
|
)
|
||||||
self.model = _ActionChunkingTransformer(cfg)
|
self.model = ACT(cfg)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""This should be called whenever the environment is reset."""
|
"""This should be called whenever the environment is reset."""
|
||||||
@@ -126,8 +126,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||||||
self.load_state_dict(d)
|
self.load_state_dict(d)
|
||||||
|
|
||||||
|
|
||||||
class _ActionChunkingTransformer(nn.Module):
|
class ACT(nn.Module):
|
||||||
"""Action Chunking Transformer: The underlying neural network for ActionChunkingTransformerPolicy.
|
"""Action Chunking Transformer: The underlying neural network for ACTPolicy.
|
||||||
|
|
||||||
Note: In this code we use the terms `vae_encoder`, `encoder`, and `decoder`. The meanings are as follows.
|
Note: In this code we use the terms `vae_encoder`, `encoder`, and `decoder`. The meanings are as follows.
|
||||||
- The `vae_encoder` is, as per the literature around variational auto-encoders (VAE), the part of the
|
- The `vae_encoder` is, as per the literature around variational auto-encoders (VAE), the part of the
|
||||||
@@ -161,13 +161,13 @@ class _ActionChunkingTransformer(nn.Module):
|
|||||||
└───────────────────────┘
|
└───────────────────────┘
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, cfg: ActionChunkingTransformerConfig):
|
def __init__(self, cfg: ACTConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
|
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
|
||||||
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
|
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
|
||||||
if self.cfg.use_vae:
|
if self.cfg.use_vae:
|
||||||
self.vae_encoder = _TransformerEncoder(cfg)
|
self.vae_encoder = ACTEncoder(cfg)
|
||||||
self.vae_encoder_cls_embed = nn.Embedding(1, cfg.d_model)
|
self.vae_encoder_cls_embed = nn.Embedding(1, cfg.d_model)
|
||||||
# Projection layer for joint-space configuration to hidden dimension.
|
# Projection layer for joint-space configuration to hidden dimension.
|
||||||
self.vae_encoder_robot_state_input_proj = nn.Linear(
|
self.vae_encoder_robot_state_input_proj = nn.Linear(
|
||||||
@@ -184,7 +184,7 @@ class _ActionChunkingTransformer(nn.Module):
|
|||||||
# dimension.
|
# dimension.
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"vae_encoder_pos_enc",
|
"vae_encoder_pos_enc",
|
||||||
_create_sinusoidal_position_embedding(1 + 1 + cfg.chunk_size, cfg.d_model).unsqueeze(0),
|
create_sinusoidal_position_embedding(1 + 1 + cfg.chunk_size, cfg.d_model).unsqueeze(0),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Backbone for image feature extraction.
|
# Backbone for image feature extraction.
|
||||||
@@ -199,8 +199,8 @@ class _ActionChunkingTransformer(nn.Module):
|
|||||||
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
|
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
|
||||||
|
|
||||||
# Transformer (acts as VAE decoder when training with the variational objective).
|
# Transformer (acts as VAE decoder when training with the variational objective).
|
||||||
self.encoder = _TransformerEncoder(cfg)
|
self.encoder = ACTEncoder(cfg)
|
||||||
self.decoder = _TransformerDecoder(cfg)
|
self.decoder = ACTDecoder(cfg)
|
||||||
|
|
||||||
# Transformer encoder input projections. The tokens will be structured like
|
# Transformer encoder input projections. The tokens will be structured like
|
||||||
# [latent, robot_state, image_feature_map_pixels].
|
# [latent, robot_state, image_feature_map_pixels].
|
||||||
@@ -211,7 +211,7 @@ class _ActionChunkingTransformer(nn.Module):
|
|||||||
)
|
)
|
||||||
# Transformer encoder positional embeddings.
|
# Transformer encoder positional embeddings.
|
||||||
self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, cfg.d_model)
|
self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, cfg.d_model)
|
||||||
self.encoder_cam_feat_pos_embed = _SinusoidalPositionEmbedding2D(cfg.d_model // 2)
|
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(cfg.d_model // 2)
|
||||||
|
|
||||||
# Transformer decoder.
|
# Transformer decoder.
|
||||||
# Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
|
# Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
|
||||||
@@ -341,12 +341,12 @@ class _ActionChunkingTransformer(nn.Module):
|
|||||||
return actions, (mu, log_sigma_x2)
|
return actions, (mu, log_sigma_x2)
|
||||||
|
|
||||||
|
|
||||||
class _TransformerEncoder(nn.Module):
|
class ACTEncoder(nn.Module):
|
||||||
"""Convenience module for running multiple encoder layers, maybe followed by normalization."""
|
"""Convenience module for running multiple encoder layers, maybe followed by normalization."""
|
||||||
|
|
||||||
def __init__(self, cfg: ActionChunkingTransformerConfig):
|
def __init__(self, cfg: ACTConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.layers = nn.ModuleList([_TransformerEncoderLayer(cfg) for _ in range(cfg.n_encoder_layers)])
|
self.layers = nn.ModuleList([ACTEncoderLayer(cfg) for _ in range(cfg.n_encoder_layers)])
|
||||||
self.norm = nn.LayerNorm(cfg.d_model) if cfg.pre_norm else nn.Identity()
|
self.norm = nn.LayerNorm(cfg.d_model) if cfg.pre_norm else nn.Identity()
|
||||||
|
|
||||||
def forward(self, x: Tensor, pos_embed: Tensor | None = None) -> Tensor:
|
def forward(self, x: Tensor, pos_embed: Tensor | None = None) -> Tensor:
|
||||||
@@ -356,8 +356,8 @@ class _TransformerEncoder(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class _TransformerEncoderLayer(nn.Module):
|
class ACTEncoderLayer(nn.Module):
|
||||||
def __init__(self, cfg: ActionChunkingTransformerConfig):
|
def __init__(self, cfg: ACTConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)
|
self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)
|
||||||
|
|
||||||
@@ -371,7 +371,7 @@ class _TransformerEncoderLayer(nn.Module):
|
|||||||
self.dropout1 = nn.Dropout(cfg.dropout)
|
self.dropout1 = nn.Dropout(cfg.dropout)
|
||||||
self.dropout2 = nn.Dropout(cfg.dropout)
|
self.dropout2 = nn.Dropout(cfg.dropout)
|
||||||
|
|
||||||
self.activation = _get_activation_fn(cfg.feedforward_activation)
|
self.activation = get_activation_fn(cfg.feedforward_activation)
|
||||||
self.pre_norm = cfg.pre_norm
|
self.pre_norm = cfg.pre_norm
|
||||||
|
|
||||||
def forward(self, x, pos_embed: Tensor | None = None) -> Tensor:
|
def forward(self, x, pos_embed: Tensor | None = None) -> Tensor:
|
||||||
@@ -394,11 +394,11 @@ class _TransformerEncoderLayer(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class _TransformerDecoder(nn.Module):
|
class ACTDecoder(nn.Module):
|
||||||
def __init__(self, cfg: ActionChunkingTransformerConfig):
|
def __init__(self, cfg: ACTConfig):
|
||||||
"""Convenience module for running multiple decoder layers followed by normalization."""
|
"""Convenience module for running multiple decoder layers followed by normalization."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.layers = nn.ModuleList([_TransformerDecoderLayer(cfg) for _ in range(cfg.n_decoder_layers)])
|
self.layers = nn.ModuleList([ACTDecoderLayer(cfg) for _ in range(cfg.n_decoder_layers)])
|
||||||
self.norm = nn.LayerNorm(cfg.d_model)
|
self.norm = nn.LayerNorm(cfg.d_model)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -417,8 +417,8 @@ class _TransformerDecoder(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class _TransformerDecoderLayer(nn.Module):
|
class ACTDecoderLayer(nn.Module):
|
||||||
def __init__(self, cfg: ActionChunkingTransformerConfig):
|
def __init__(self, cfg: ACTConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)
|
self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)
|
||||||
self.multihead_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)
|
self.multihead_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)
|
||||||
@@ -435,7 +435,7 @@ class _TransformerDecoderLayer(nn.Module):
|
|||||||
self.dropout2 = nn.Dropout(cfg.dropout)
|
self.dropout2 = nn.Dropout(cfg.dropout)
|
||||||
self.dropout3 = nn.Dropout(cfg.dropout)
|
self.dropout3 = nn.Dropout(cfg.dropout)
|
||||||
|
|
||||||
self.activation = _get_activation_fn(cfg.feedforward_activation)
|
self.activation = get_activation_fn(cfg.feedforward_activation)
|
||||||
self.pre_norm = cfg.pre_norm
|
self.pre_norm = cfg.pre_norm
|
||||||
|
|
||||||
def maybe_add_pos_embed(self, tensor: Tensor, pos_embed: Tensor | None) -> Tensor:
|
def maybe_add_pos_embed(self, tensor: Tensor, pos_embed: Tensor | None) -> Tensor:
|
||||||
@@ -489,7 +489,7 @@ class _TransformerDecoderLayer(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def _create_sinusoidal_position_embedding(num_positions: int, dimension: int) -> Tensor:
|
def create_sinusoidal_position_embedding(num_positions: int, dimension: int) -> Tensor:
|
||||||
"""1D sinusoidal positional embeddings as in Attention is All You Need.
|
"""1D sinusoidal positional embeddings as in Attention is All You Need.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -507,7 +507,7 @@ def _create_sinusoidal_position_embedding(num_positions: int, dimension: int) ->
|
|||||||
return torch.from_numpy(sinusoid_table).float()
|
return torch.from_numpy(sinusoid_table).float()
|
||||||
|
|
||||||
|
|
||||||
class _SinusoidalPositionEmbedding2D(nn.Module):
|
class ACTSinusoidalPositionEmbedding2d(nn.Module):
|
||||||
"""2D sinusoidal positional embeddings similar to what's presented in Attention Is All You Need.
|
"""2D sinusoidal positional embeddings similar to what's presented in Attention Is All You Need.
|
||||||
|
|
||||||
The variation is that the position indices are normalized in [0, 2π] (not quite: the lower bound is 1/H
|
The variation is that the position indices are normalized in [0, 2π] (not quite: the lower bound is 1/H
|
||||||
@@ -561,7 +561,7 @@ class _SinusoidalPositionEmbedding2D(nn.Module):
|
|||||||
return pos_embed
|
return pos_embed
|
||||||
|
|
||||||
|
|
||||||
def _get_activation_fn(activation: str) -> Callable:
|
def get_activation_fn(activation: str) -> Callable:
|
||||||
"""Return an activation function given a string."""
|
"""Return an activation function given a string."""
|
||||||
if activation == "relu":
|
if activation == "relu":
|
||||||
return F.relu
|
return F.relu
|
||||||
|
|||||||
@@ -63,14 +63,14 @@ class DiffusionPolicy(nn.Module):
|
|||||||
# queues are populated during rollout of the policy, they contain the n latest observations and actions
|
# queues are populated during rollout of the policy, they contain the n latest observations and actions
|
||||||
self._queues = None
|
self._queues = None
|
||||||
|
|
||||||
self.diffusion = _DiffusionUnetImagePolicy(cfg)
|
self.diffusion = DiffusionModel(cfg)
|
||||||
|
|
||||||
# TODO(alexander-soare): This should probably be managed outside of the policy class.
|
# TODO(alexander-soare): This should probably be managed outside of the policy class.
|
||||||
self.ema_diffusion = None
|
self.ema_diffusion = None
|
||||||
self.ema = None
|
self.ema = None
|
||||||
if self.cfg.use_ema:
|
if self.cfg.use_ema:
|
||||||
self.ema_diffusion = copy.deepcopy(self.diffusion)
|
self.ema_diffusion = copy.deepcopy(self.diffusion)
|
||||||
self.ema = _EMA(cfg, model=self.ema_diffusion)
|
self.ema = DiffusionEMA(cfg, model=self.ema_diffusion)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""
|
"""
|
||||||
@@ -152,13 +152,13 @@ class DiffusionPolicy(nn.Module):
|
|||||||
assert len(unexpected_keys) == 0
|
assert len(unexpected_keys) == 0
|
||||||
|
|
||||||
|
|
||||||
class _DiffusionUnetImagePolicy(nn.Module):
|
class DiffusionModel(nn.Module):
|
||||||
def __init__(self, cfg: DiffusionConfig):
|
def __init__(self, cfg: DiffusionConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
|
|
||||||
self.rgb_encoder = _RgbEncoder(cfg)
|
self.rgb_encoder = DiffusionRgbEncoder(cfg)
|
||||||
self.unet = _ConditionalUnet1D(
|
self.unet = DiffusionConditionalUnet1d(
|
||||||
cfg,
|
cfg,
|
||||||
global_cond_dim=(cfg.output_shapes["action"][0] + self.rgb_encoder.feature_dim) * cfg.n_obs_steps,
|
global_cond_dim=(cfg.output_shapes["action"][0] + self.rgb_encoder.feature_dim) * cfg.n_obs_steps,
|
||||||
)
|
)
|
||||||
@@ -300,7 +300,7 @@ class _DiffusionUnetImagePolicy(nn.Module):
|
|||||||
return loss.mean()
|
return loss.mean()
|
||||||
|
|
||||||
|
|
||||||
class _RgbEncoder(nn.Module):
|
class DiffusionRgbEncoder(nn.Module):
|
||||||
"""Encode an RGB image into a 1D feature vector.
|
"""Encode an RGB image into a 1D feature vector.
|
||||||
|
|
||||||
Includes the ability to normalize and crop the image first.
|
Includes the ability to normalize and crop the image first.
|
||||||
@@ -403,7 +403,7 @@ def _replace_submodules(
|
|||||||
return root_module
|
return root_module
|
||||||
|
|
||||||
|
|
||||||
class _SinusoidalPosEmb(nn.Module):
|
class DiffusionSinusoidalPosEmb(nn.Module):
|
||||||
"""1D sinusoidal positional embeddings as in Attention is All You Need."""
|
"""1D sinusoidal positional embeddings as in Attention is All You Need."""
|
||||||
|
|
||||||
def __init__(self, dim: int):
|
def __init__(self, dim: int):
|
||||||
@@ -420,7 +420,7 @@ class _SinusoidalPosEmb(nn.Module):
|
|||||||
return emb
|
return emb
|
||||||
|
|
||||||
|
|
||||||
class _Conv1dBlock(nn.Module):
|
class DiffusionConv1dBlock(nn.Module):
|
||||||
"""Conv1d --> GroupNorm --> Mish"""
|
"""Conv1d --> GroupNorm --> Mish"""
|
||||||
|
|
||||||
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
|
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
|
||||||
@@ -436,7 +436,7 @@ class _Conv1dBlock(nn.Module):
|
|||||||
return self.block(x)
|
return self.block(x)
|
||||||
|
|
||||||
|
|
||||||
class _ConditionalUnet1D(nn.Module):
|
class DiffusionConditionalUnet1d(nn.Module):
|
||||||
"""A 1D convolutional UNet with FiLM modulation for conditioning.
|
"""A 1D convolutional UNet with FiLM modulation for conditioning.
|
||||||
|
|
||||||
Note: this removes local conditioning as compared to the original diffusion policy code.
|
Note: this removes local conditioning as compared to the original diffusion policy code.
|
||||||
@@ -449,7 +449,7 @@ class _ConditionalUnet1D(nn.Module):
|
|||||||
|
|
||||||
# Encoder for the diffusion timestep.
|
# Encoder for the diffusion timestep.
|
||||||
self.diffusion_step_encoder = nn.Sequential(
|
self.diffusion_step_encoder = nn.Sequential(
|
||||||
_SinusoidalPosEmb(cfg.diffusion_step_embed_dim),
|
DiffusionSinusoidalPosEmb(cfg.diffusion_step_embed_dim),
|
||||||
nn.Linear(cfg.diffusion_step_embed_dim, cfg.diffusion_step_embed_dim * 4),
|
nn.Linear(cfg.diffusion_step_embed_dim, cfg.diffusion_step_embed_dim * 4),
|
||||||
nn.Mish(),
|
nn.Mish(),
|
||||||
nn.Linear(cfg.diffusion_step_embed_dim * 4, cfg.diffusion_step_embed_dim),
|
nn.Linear(cfg.diffusion_step_embed_dim * 4, cfg.diffusion_step_embed_dim),
|
||||||
@@ -477,8 +477,8 @@ class _ConditionalUnet1D(nn.Module):
|
|||||||
self.down_modules.append(
|
self.down_modules.append(
|
||||||
nn.ModuleList(
|
nn.ModuleList(
|
||||||
[
|
[
|
||||||
_ConditionalResidualBlock1D(dim_in, dim_out, **common_res_block_kwargs),
|
DiffusionConditionalResidualBlock1d(dim_in, dim_out, **common_res_block_kwargs),
|
||||||
_ConditionalResidualBlock1D(dim_out, dim_out, **common_res_block_kwargs),
|
DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
|
||||||
# Downsample as long as it is not the last block.
|
# Downsample as long as it is not the last block.
|
||||||
nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(),
|
nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(),
|
||||||
]
|
]
|
||||||
@@ -488,8 +488,12 @@ class _ConditionalUnet1D(nn.Module):
|
|||||||
# Processing in the middle of the auto-encoder.
|
# Processing in the middle of the auto-encoder.
|
||||||
self.mid_modules = nn.ModuleList(
|
self.mid_modules = nn.ModuleList(
|
||||||
[
|
[
|
||||||
_ConditionalResidualBlock1D(cfg.down_dims[-1], cfg.down_dims[-1], **common_res_block_kwargs),
|
DiffusionConditionalResidualBlock1d(
|
||||||
_ConditionalResidualBlock1D(cfg.down_dims[-1], cfg.down_dims[-1], **common_res_block_kwargs),
|
cfg.down_dims[-1], cfg.down_dims[-1], **common_res_block_kwargs
|
||||||
|
),
|
||||||
|
DiffusionConditionalResidualBlock1d(
|
||||||
|
cfg.down_dims[-1], cfg.down_dims[-1], **common_res_block_kwargs
|
||||||
|
),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -501,8 +505,8 @@ class _ConditionalUnet1D(nn.Module):
|
|||||||
nn.ModuleList(
|
nn.ModuleList(
|
||||||
[
|
[
|
||||||
# dim_in * 2, because it takes the encoder's skip connection as well
|
# dim_in * 2, because it takes the encoder's skip connection as well
|
||||||
_ConditionalResidualBlock1D(dim_in * 2, dim_out, **common_res_block_kwargs),
|
DiffusionConditionalResidualBlock1d(dim_in * 2, dim_out, **common_res_block_kwargs),
|
||||||
_ConditionalResidualBlock1D(dim_out, dim_out, **common_res_block_kwargs),
|
DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
|
||||||
# Upsample as long as it is not the last block.
|
# Upsample as long as it is not the last block.
|
||||||
nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
|
nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
|
||||||
]
|
]
|
||||||
@@ -510,7 +514,7 @@ class _ConditionalUnet1D(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.final_conv = nn.Sequential(
|
self.final_conv = nn.Sequential(
|
||||||
_Conv1dBlock(cfg.down_dims[0], cfg.down_dims[0], kernel_size=cfg.kernel_size),
|
DiffusionConv1dBlock(cfg.down_dims[0], cfg.down_dims[0], kernel_size=cfg.kernel_size),
|
||||||
nn.Conv1d(cfg.down_dims[0], cfg.output_shapes["action"][0], 1),
|
nn.Conv1d(cfg.down_dims[0], cfg.output_shapes["action"][0], 1),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -559,7 +563,7 @@ class _ConditionalUnet1D(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class _ConditionalResidualBlock1D(nn.Module):
|
class DiffusionConditionalResidualBlock1d(nn.Module):
|
||||||
"""ResNet style 1D convolutional block with FiLM modulation for conditioning."""
|
"""ResNet style 1D convolutional block with FiLM modulation for conditioning."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -578,13 +582,13 @@ class _ConditionalResidualBlock1D(nn.Module):
|
|||||||
self.use_film_scale_modulation = use_film_scale_modulation
|
self.use_film_scale_modulation = use_film_scale_modulation
|
||||||
self.out_channels = out_channels
|
self.out_channels = out_channels
|
||||||
|
|
||||||
self.conv1 = _Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
|
self.conv1 = DiffusionConv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
|
||||||
|
|
||||||
# FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale.
|
# FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale.
|
||||||
cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels
|
cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels
|
||||||
self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))
|
self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))
|
||||||
|
|
||||||
self.conv2 = _Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
|
self.conv2 = DiffusionConv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
|
||||||
|
|
||||||
# A final convolution for dimension matching the residual (if needed).
|
# A final convolution for dimension matching the residual (if needed).
|
||||||
self.residual_conv = (
|
self.residual_conv = (
|
||||||
@@ -617,7 +621,7 @@ class _ConditionalResidualBlock1D(nn.Module):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class _EMA:
|
class DiffusionEMA:
|
||||||
"""
|
"""
|
||||||
Exponential moving average of the model's weights
|
Exponential moving average of the model's weights
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -38,11 +38,11 @@ def make_policy(hydra_cfg: DictConfig, dataset_stats=None):
|
|||||||
policy = DiffusionPolicy(policy_cfg, hydra_cfg.training.offline_steps, dataset_stats)
|
policy = DiffusionPolicy(policy_cfg, hydra_cfg.training.offline_steps, dataset_stats)
|
||||||
policy.to(get_safe_torch_device(hydra_cfg.device))
|
policy.to(get_safe_torch_device(hydra_cfg.device))
|
||||||
elif hydra_cfg.policy.name == "act":
|
elif hydra_cfg.policy.name == "act":
|
||||||
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
|
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||||
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
|
from lerobot.common.policies.act.modeling_act import ACTPolicy
|
||||||
|
|
||||||
policy_cfg = _policy_cfg_from_hydra_cfg(ActionChunkingTransformerConfig, hydra_cfg)
|
policy_cfg = _policy_cfg_from_hydra_cfg(ACTConfig, hydra_cfg)
|
||||||
policy = ActionChunkingTransformerPolicy(policy_cfg, dataset_stats)
|
policy = ACTPolicy(policy_cfg, dataset_stats)
|
||||||
policy.to(get_safe_torch_device(hydra_cfg.device))
|
policy.to(get_safe_torch_device(hydra_cfg.device))
|
||||||
else:
|
else:
|
||||||
raise ValueError(hydra_cfg.policy.name)
|
raise ValueError(hydra_cfg.policy.name)
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import gymnasium as gym
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import lerobot
|
import lerobot
|
||||||
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
|
from lerobot.common.policies.act.modeling_act import ACTPolicy
|
||||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||||
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
|
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
|
||||||
from tests.utils import require_env
|
from tests.utils import require_env
|
||||||
@@ -30,7 +30,7 @@ def test_available_policies():
|
|||||||
consistent with those listed in `lerobot/__init__.py`.
|
consistent with those listed in `lerobot/__init__.py`.
|
||||||
"""
|
"""
|
||||||
policy_classes = [
|
policy_classes = [
|
||||||
ActionChunkingTransformerPolicy,
|
ACTPolicy,
|
||||||
DiffusionPolicy,
|
DiffusionPolicy,
|
||||||
TDMPCPolicy,
|
TDMPCPolicy,
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from lerobot.common.datasets.factory import make_dataset
|
|||||||
from lerobot.common.datasets.utils import cycle
|
from lerobot.common.datasets.utils import cycle
|
||||||
from lerobot.common.envs.factory import make_env
|
from lerobot.common.envs.factory import make_env
|
||||||
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
|
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
|
||||||
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
|
from lerobot.common.policies.act.modeling_act import ACTPolicy
|
||||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||||
from lerobot.common.policies.factory import make_policy
|
from lerobot.common.policies.factory import make_policy
|
||||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||||
@@ -115,7 +115,7 @@ def test_policy(env_name, policy_name, extra_overrides):
|
|||||||
new_policy.load_state_dict(policy.state_dict())
|
new_policy.load_state_dict(policy.state_dict())
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("policy_cls", [DiffusionPolicy, ActionChunkingTransformerPolicy])
|
@pytest.mark.parametrize("policy_cls", [DiffusionPolicy, ACTPolicy])
|
||||||
def test_policy_defaults(policy_cls):
|
def test_policy_defaults(policy_cls):
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
# TODO(alexander-soare): Remove this kwargs hack when we move the scheduler out of DP.
|
# TODO(alexander-soare): Remove this kwargs hack when we move the scheduler out of DP.
|
||||||
|
|||||||
Reference in New Issue
Block a user