forked from tangger/lerobot
revision 1
@@ -20,10 +20,10 @@ from torch import Tensor, nn
 from torchvision.models._utils import IntermediateLayerGetter
 from torchvision.ops.misc import FrozenBatchNorm2d
 
-from lerobot.common.policies.act.configuration_act import ActConfig
+from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
 
 
-class ActPolicy(nn.Module):
+class ActionChunkingTransformerPolicy(nn.Module):
     """
     Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
     Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
@@ -61,9 +61,11 @@ class ActPolicy(nn.Module):
     """
 
     name = "act"
-    _multiple_obs_steps_not_handled_msg = "ActPolicy does not handle multiple observation steps."
+    _multiple_obs_steps_not_handled_msg = (
+        "ActionChunkingTransformerPolicy does not handle multiple observation steps."
+    )
 
-    def __init__(self, cfg: ActConfig):
+    def __init__(self, cfg: ActionChunkingTransformerConfig):
         """
         TODO(alexander-soare): Add documentation for all parameters once we have model configs established.
         """
@@ -398,7 +400,7 @@ class ActPolicy(nn.Module):
 class _TransformerEncoder(nn.Module):
     """Convenience module for running multiple encoder layers, maybe followed by normalization."""
 
-    def __init__(self, cfg: ActConfig):
+    def __init__(self, cfg: ActionChunkingTransformerConfig):
         super().__init__()
         self.layers = nn.ModuleList([_TransformerEncoderLayer(cfg) for _ in range(cfg.n_encoder_layers)])
         self.norm = nn.LayerNorm(cfg.d_model) if cfg.pre_norm else nn.Identity()
@@ -411,7 +413,7 @@ class _TransformerEncoder(nn.Module):
 
 
 class _TransformerEncoderLayer(nn.Module):
-    def __init__(self, cfg: ActConfig):
+    def __init__(self, cfg: ActionChunkingTransformerConfig):
         super().__init__()
         self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)
 
@@ -449,7 +451,7 @@ class _TransformerEncoderLayer(nn.Module):
 
 
 class _TransformerDecoder(nn.Module):
-    def __init__(self, cfg: ActConfig):
+    def __init__(self, cfg: ActionChunkingTransformerConfig):
         """Convenience module for running multiple decoder layers followed by normalization."""
         super().__init__()
         self.layers = nn.ModuleList([_TransformerDecoderLayer(cfg) for _ in range(cfg.n_decoder_layers)])
@@ -472,7 +474,7 @@ class _TransformerDecoder(nn.Module):
 
 
 class _TransformerDecoderLayer(nn.Module):
-    def __init__(self, cfg: ActConfig):
+    def __init__(self, cfg: ActionChunkingTransformerConfig):
         super().__init__()
         self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)
         self.multihead_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)
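A minimal call-site sketch (not part of this commit) of how downstream code would look after the rename. Only the `configuration_act` import path appears in the diff; the `modeling_act` module path and the no-argument config constructor are assumptions for illustration.

```python
# Hypothetical call-site update after this rename.
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy  # assumed module path

cfg = ActionChunkingTransformerConfig()        # was: ActConfig()
policy = ActionChunkingTransformerPolicy(cfg)  # was: ActPolicy(cfg)
```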