Follow transformers single file naming conventions (#124)

This commit is contained in:
Alexander Soare
2024-05-01 13:09:42 +01:00
committed by GitHub
parent 986583dc5c
commit 01d5490d44
6 changed files with 62 additions and 58 deletions

View File

@@ -18,11 +18,11 @@ from torch import Tensor, nn
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.ops.misc import FrozenBatchNorm2d
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize
class ActionChunkingTransformerPolicy(nn.Module):
class ACTPolicy(nn.Module):
"""
Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
@@ -30,7 +30,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
name = "act"
def __init__(self, cfg: ActionChunkingTransformerConfig | None = None, dataset_stats=None):
def __init__(self, cfg: ACTConfig | None = None, dataset_stats=None):
"""
Args:
cfg: Policy configuration class instance or None, in which case the default instantiation of the
@@ -38,14 +38,14 @@ class ActionChunkingTransformerPolicy(nn.Module):
"""
super().__init__()
if cfg is None:
cfg = ActionChunkingTransformerConfig()
cfg = ACTConfig()
self.cfg = cfg
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.input_normalization_modes, dataset_stats)
self.normalize_targets = Normalize(cfg.output_shapes, cfg.output_normalization_modes, dataset_stats)
self.unnormalize_outputs = Unnormalize(
cfg.output_shapes, cfg.output_normalization_modes, dataset_stats
)
self.model = _ActionChunkingTransformer(cfg)
self.model = ACT(cfg)
def reset(self):
"""This should be called whenever the environment is reset."""
@@ -126,8 +126,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
self.load_state_dict(d)
class _ActionChunkingTransformer(nn.Module):
"""Action Chunking Transformer: The underlying neural network for ActionChunkingTransformerPolicy.
class ACT(nn.Module):
"""Action Chunking Transformer: The underlying neural network for ACTPolicy.
Note: In this code we use the terms `vae_encoder`, `encoder`, `decoder`. The meanings are as follows.
- The `vae_encoder` is, as per the literature around variational auto-encoders (VAE), the part of the
@@ -161,13 +161,13 @@ class _ActionChunkingTransformer(nn.Module):
└───────────────────────┘
"""
def __init__(self, cfg: ActionChunkingTransformerConfig):
def __init__(self, cfg: ACTConfig):
super().__init__()
self.cfg = cfg
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
if self.cfg.use_vae:
self.vae_encoder = _TransformerEncoder(cfg)
self.vae_encoder = ACTEncoder(cfg)
self.vae_encoder_cls_embed = nn.Embedding(1, cfg.d_model)
# Projection layer for joint-space configuration to hidden dimension.
self.vae_encoder_robot_state_input_proj = nn.Linear(
@@ -184,7 +184,7 @@ class _ActionChunkingTransformer(nn.Module):
# dimension.
self.register_buffer(
"vae_encoder_pos_enc",
_create_sinusoidal_position_embedding(1 + 1 + cfg.chunk_size, cfg.d_model).unsqueeze(0),
create_sinusoidal_position_embedding(1 + 1 + cfg.chunk_size, cfg.d_model).unsqueeze(0),
)
# Backbone for image feature extraction.
@@ -199,8 +199,8 @@ class _ActionChunkingTransformer(nn.Module):
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
# Transformer (acts as VAE decoder when training with the variational objective).
self.encoder = _TransformerEncoder(cfg)
self.decoder = _TransformerDecoder(cfg)
self.encoder = ACTEncoder(cfg)
self.decoder = ACTDecoder(cfg)
# Transformer encoder input projections. The tokens will be structured like
# [latent, robot_state, image_feature_map_pixels].
@@ -211,7 +211,7 @@ class _ActionChunkingTransformer(nn.Module):
)
# Transformer encoder positional embeddings.
self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, cfg.d_model)
self.encoder_cam_feat_pos_embed = _SinusoidalPositionEmbedding2D(cfg.d_model // 2)
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(cfg.d_model // 2)
# Transformer decoder.
# Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
@@ -341,12 +341,12 @@ class _ActionChunkingTransformer(nn.Module):
return actions, (mu, log_sigma_x2)
class _TransformerEncoder(nn.Module):
class ACTEncoder(nn.Module):
"""Convenience module for running multiple encoder layers, maybe followed by normalization."""
def __init__(self, cfg: ActionChunkingTransformerConfig):
def __init__(self, cfg: ACTConfig):
super().__init__()
self.layers = nn.ModuleList([_TransformerEncoderLayer(cfg) for _ in range(cfg.n_encoder_layers)])
self.layers = nn.ModuleList([ACTEncoderLayer(cfg) for _ in range(cfg.n_encoder_layers)])
self.norm = nn.LayerNorm(cfg.d_model) if cfg.pre_norm else nn.Identity()
def forward(self, x: Tensor, pos_embed: Tensor | None = None) -> Tensor:
@@ -356,8 +356,8 @@ class _TransformerEncoder(nn.Module):
return x
class _TransformerEncoderLayer(nn.Module):
def __init__(self, cfg: ActionChunkingTransformerConfig):
class ACTEncoderLayer(nn.Module):
def __init__(self, cfg: ACTConfig):
super().__init__()
self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)
@@ -371,7 +371,7 @@ class _TransformerEncoderLayer(nn.Module):
self.dropout1 = nn.Dropout(cfg.dropout)
self.dropout2 = nn.Dropout(cfg.dropout)
self.activation = _get_activation_fn(cfg.feedforward_activation)
self.activation = get_activation_fn(cfg.feedforward_activation)
self.pre_norm = cfg.pre_norm
def forward(self, x, pos_embed: Tensor | None = None) -> Tensor:
@@ -394,11 +394,11 @@ class _TransformerEncoderLayer(nn.Module):
return x
class _TransformerDecoder(nn.Module):
def __init__(self, cfg: ActionChunkingTransformerConfig):
class ACTDecoder(nn.Module):
def __init__(self, cfg: ACTConfig):
"""Convenience module for running multiple decoder layers followed by normalization."""
super().__init__()
self.layers = nn.ModuleList([_TransformerDecoderLayer(cfg) for _ in range(cfg.n_decoder_layers)])
self.layers = nn.ModuleList([ACTDecoderLayer(cfg) for _ in range(cfg.n_decoder_layers)])
self.norm = nn.LayerNorm(cfg.d_model)
def forward(
@@ -417,8 +417,8 @@ class _TransformerDecoder(nn.Module):
return x
class _TransformerDecoderLayer(nn.Module):
def __init__(self, cfg: ActionChunkingTransformerConfig):
class ACTDecoderLayer(nn.Module):
def __init__(self, cfg: ACTConfig):
super().__init__()
self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)
self.multihead_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)
@@ -435,7 +435,7 @@ class _TransformerDecoderLayer(nn.Module):
self.dropout2 = nn.Dropout(cfg.dropout)
self.dropout3 = nn.Dropout(cfg.dropout)
self.activation = _get_activation_fn(cfg.feedforward_activation)
self.activation = get_activation_fn(cfg.feedforward_activation)
self.pre_norm = cfg.pre_norm
def maybe_add_pos_embed(self, tensor: Tensor, pos_embed: Tensor | None) -> Tensor:
@@ -489,7 +489,7 @@ class _TransformerDecoderLayer(nn.Module):
return x
def _create_sinusoidal_position_embedding(num_positions: int, dimension: int) -> Tensor:
def create_sinusoidal_position_embedding(num_positions: int, dimension: int) -> Tensor:
"""1D sinusoidal positional embeddings as in Attention is All You Need.
Args:
@@ -507,7 +507,7 @@ def _create_sinusoidal_position_embedding(num_positions: int, dimension: int) ->
return torch.from_numpy(sinusoid_table).float()
class _SinusoidalPositionEmbedding2D(nn.Module):
class ACTSinusoidalPositionEmbedding2d(nn.Module):
"""2D sinusoidal positional embeddings similar to what's presented in Attention Is All You Need.
The variation is that the position indices are normalized in [0, 2π] (not quite: the lower bound is 1/H
@@ -561,7 +561,7 @@ class _SinusoidalPositionEmbedding2D(nn.Module):
return pos_embed
def _get_activation_fn(activation: str) -> Callable:
def get_activation_fn(activation: str) -> Callable:
"""Return an activation function given a string."""
if activation == "relu":
return F.relu