Follow transformers single file naming conventions (#124)

This commit is contained in:
Alexander Soare
2024-05-01 13:09:42 +01:00
committed by GitHub
parent 986583dc5c
commit 01d5490d44
6 changed files with 62 additions and 58 deletions

View File

@@ -2,7 +2,7 @@ from dataclasses import dataclass, field
@dataclass @dataclass
class ActionChunkingTransformerConfig: class ACTConfig:
"""Configuration class for the Action Chunking Transformers policy. """Configuration class for the Action Chunking Transformers policy.
Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer". Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".

View File

@@ -18,11 +18,11 @@ from torch import Tensor, nn
from torchvision.models._utils import IntermediateLayerGetter from torchvision.models._utils import IntermediateLayerGetter
from torchvision.ops.misc import FrozenBatchNorm2d from torchvision.ops.misc import FrozenBatchNorm2d
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.normalize import Normalize, Unnormalize
class ActionChunkingTransformerPolicy(nn.Module): class ACTPolicy(nn.Module):
""" """
Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act) Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
@@ -30,7 +30,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
name = "act" name = "act"
def __init__(self, cfg: ActionChunkingTransformerConfig | None = None, dataset_stats=None): def __init__(self, cfg: ACTConfig | None = None, dataset_stats=None):
""" """
Args: Args:
cfg: Policy configuration class instance or None, in which case the default instantiation of the cfg: Policy configuration class instance or None, in which case the default instantiation of the
@@ -38,14 +38,14 @@ class ActionChunkingTransformerPolicy(nn.Module):
""" """
super().__init__() super().__init__()
if cfg is None: if cfg is None:
cfg = ActionChunkingTransformerConfig() cfg = ACTConfig()
self.cfg = cfg self.cfg = cfg
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.input_normalization_modes, dataset_stats) self.normalize_inputs = Normalize(cfg.input_shapes, cfg.input_normalization_modes, dataset_stats)
self.normalize_targets = Normalize(cfg.output_shapes, cfg.output_normalization_modes, dataset_stats) self.normalize_targets = Normalize(cfg.output_shapes, cfg.output_normalization_modes, dataset_stats)
self.unnormalize_outputs = Unnormalize( self.unnormalize_outputs = Unnormalize(
cfg.output_shapes, cfg.output_normalization_modes, dataset_stats cfg.output_shapes, cfg.output_normalization_modes, dataset_stats
) )
self.model = _ActionChunkingTransformer(cfg) self.model = ACT(cfg)
def reset(self): def reset(self):
"""This should be called whenever the environment is reset.""" """This should be called whenever the environment is reset."""
@@ -126,8 +126,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
self.load_state_dict(d) self.load_state_dict(d)
class _ActionChunkingTransformer(nn.Module): class ACT(nn.Module):
"""Action Chunking Transformer: The underlying neural network for ActionChunkingTransformerPolicy. """Action Chunking Transformer: The underlying neural network for ACTPolicy.
Note: In this code we use the terms `vae_encoder`, `encoder`, `decoder`. The meanings are as follows. Note: In this code we use the terms `vae_encoder`, `encoder`, `decoder`. The meanings are as follows.
- The `vae_encoder` is, as per the literature around variational auto-encoders (VAE), the part of the - The `vae_encoder` is, as per the literature around variational auto-encoders (VAE), the part of the
@@ -161,13 +161,13 @@ class _ActionChunkingTransformer(nn.Module):
└───────────────────────┘ └───────────────────────┘
""" """
def __init__(self, cfg: ActionChunkingTransformerConfig): def __init__(self, cfg: ACTConfig):
super().__init__() super().__init__()
self.cfg = cfg self.cfg = cfg
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence]. # BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]). # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
if self.cfg.use_vae: if self.cfg.use_vae:
self.vae_encoder = _TransformerEncoder(cfg) self.vae_encoder = ACTEncoder(cfg)
self.vae_encoder_cls_embed = nn.Embedding(1, cfg.d_model) self.vae_encoder_cls_embed = nn.Embedding(1, cfg.d_model)
# Projection layer for joint-space configuration to hidden dimension. # Projection layer for joint-space configuration to hidden dimension.
self.vae_encoder_robot_state_input_proj = nn.Linear( self.vae_encoder_robot_state_input_proj = nn.Linear(
@@ -184,7 +184,7 @@ class _ActionChunkingTransformer(nn.Module):
# dimension. # dimension.
self.register_buffer( self.register_buffer(
"vae_encoder_pos_enc", "vae_encoder_pos_enc",
_create_sinusoidal_position_embedding(1 + 1 + cfg.chunk_size, cfg.d_model).unsqueeze(0), create_sinusoidal_position_embedding(1 + 1 + cfg.chunk_size, cfg.d_model).unsqueeze(0),
) )
# Backbone for image feature extraction. # Backbone for image feature extraction.
@@ -199,8 +199,8 @@ class _ActionChunkingTransformer(nn.Module):
self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"}) self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"})
# Transformer (acts as VAE decoder when training with the variational objective). # Transformer (acts as VAE decoder when training with the variational objective).
self.encoder = _TransformerEncoder(cfg) self.encoder = ACTEncoder(cfg)
self.decoder = _TransformerDecoder(cfg) self.decoder = ACTDecoder(cfg)
# Transformer encoder input projections. The tokens will be structured like # Transformer encoder input projections. The tokens will be structured like
# [latent, robot_state, image_feature_map_pixels]. # [latent, robot_state, image_feature_map_pixels].
@@ -211,7 +211,7 @@ class _ActionChunkingTransformer(nn.Module):
) )
# Transformer encoder positional embeddings. # Transformer encoder positional embeddings.
self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, cfg.d_model) self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, cfg.d_model)
self.encoder_cam_feat_pos_embed = _SinusoidalPositionEmbedding2D(cfg.d_model // 2) self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(cfg.d_model // 2)
# Transformer decoder. # Transformer decoder.
# Learnable positional embedding for the transformer's decoder (in the style of DETR object queries). # Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
@@ -341,12 +341,12 @@ class _ActionChunkingTransformer(nn.Module):
return actions, (mu, log_sigma_x2) return actions, (mu, log_sigma_x2)
class _TransformerEncoder(nn.Module): class ACTEncoder(nn.Module):
"""Convenience module for running multiple encoder layers, maybe followed by normalization.""" """Convenience module for running multiple encoder layers, maybe followed by normalization."""
def __init__(self, cfg: ActionChunkingTransformerConfig): def __init__(self, cfg: ACTConfig):
super().__init__() super().__init__()
self.layers = nn.ModuleList([_TransformerEncoderLayer(cfg) for _ in range(cfg.n_encoder_layers)]) self.layers = nn.ModuleList([ACTEncoderLayer(cfg) for _ in range(cfg.n_encoder_layers)])
self.norm = nn.LayerNorm(cfg.d_model) if cfg.pre_norm else nn.Identity() self.norm = nn.LayerNorm(cfg.d_model) if cfg.pre_norm else nn.Identity()
def forward(self, x: Tensor, pos_embed: Tensor | None = None) -> Tensor: def forward(self, x: Tensor, pos_embed: Tensor | None = None) -> Tensor:
@@ -356,8 +356,8 @@ class _TransformerEncoder(nn.Module):
return x return x
class _TransformerEncoderLayer(nn.Module): class ACTEncoderLayer(nn.Module):
def __init__(self, cfg: ActionChunkingTransformerConfig): def __init__(self, cfg: ACTConfig):
super().__init__() super().__init__()
self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout) self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)
@@ -371,7 +371,7 @@ class _TransformerEncoderLayer(nn.Module):
self.dropout1 = nn.Dropout(cfg.dropout) self.dropout1 = nn.Dropout(cfg.dropout)
self.dropout2 = nn.Dropout(cfg.dropout) self.dropout2 = nn.Dropout(cfg.dropout)
self.activation = _get_activation_fn(cfg.feedforward_activation) self.activation = get_activation_fn(cfg.feedforward_activation)
self.pre_norm = cfg.pre_norm self.pre_norm = cfg.pre_norm
def forward(self, x, pos_embed: Tensor | None = None) -> Tensor: def forward(self, x, pos_embed: Tensor | None = None) -> Tensor:
@@ -394,11 +394,11 @@ class _TransformerEncoderLayer(nn.Module):
return x return x
class _TransformerDecoder(nn.Module): class ACTDecoder(nn.Module):
def __init__(self, cfg: ActionChunkingTransformerConfig): def __init__(self, cfg: ACTConfig):
"""Convenience module for running multiple decoder layers followed by normalization.""" """Convenience module for running multiple decoder layers followed by normalization."""
super().__init__() super().__init__()
self.layers = nn.ModuleList([_TransformerDecoderLayer(cfg) for _ in range(cfg.n_decoder_layers)]) self.layers = nn.ModuleList([ACTDecoderLayer(cfg) for _ in range(cfg.n_decoder_layers)])
self.norm = nn.LayerNorm(cfg.d_model) self.norm = nn.LayerNorm(cfg.d_model)
def forward( def forward(
@@ -417,8 +417,8 @@ class _TransformerDecoder(nn.Module):
return x return x
class _TransformerDecoderLayer(nn.Module): class ACTDecoderLayer(nn.Module):
def __init__(self, cfg: ActionChunkingTransformerConfig): def __init__(self, cfg: ACTConfig):
super().__init__() super().__init__()
self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout) self.self_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)
self.multihead_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout) self.multihead_attn = nn.MultiheadAttention(cfg.d_model, cfg.n_heads, dropout=cfg.dropout)
@@ -435,7 +435,7 @@ class _TransformerDecoderLayer(nn.Module):
self.dropout2 = nn.Dropout(cfg.dropout) self.dropout2 = nn.Dropout(cfg.dropout)
self.dropout3 = nn.Dropout(cfg.dropout) self.dropout3 = nn.Dropout(cfg.dropout)
self.activation = _get_activation_fn(cfg.feedforward_activation) self.activation = get_activation_fn(cfg.feedforward_activation)
self.pre_norm = cfg.pre_norm self.pre_norm = cfg.pre_norm
def maybe_add_pos_embed(self, tensor: Tensor, pos_embed: Tensor | None) -> Tensor: def maybe_add_pos_embed(self, tensor: Tensor, pos_embed: Tensor | None) -> Tensor:
@@ -489,7 +489,7 @@ class _TransformerDecoderLayer(nn.Module):
return x return x
def _create_sinusoidal_position_embedding(num_positions: int, dimension: int) -> Tensor: def create_sinusoidal_position_embedding(num_positions: int, dimension: int) -> Tensor:
"""1D sinusoidal positional embeddings as in Attention is All You Need. """1D sinusoidal positional embeddings as in Attention is All You Need.
Args: Args:
@@ -507,7 +507,7 @@ def _create_sinusoidal_position_embedding(num_positions: int, dimension: int) ->
return torch.from_numpy(sinusoid_table).float() return torch.from_numpy(sinusoid_table).float()
class _SinusoidalPositionEmbedding2D(nn.Module): class ACTSinusoidalPositionEmbedding2d(nn.Module):
"""2D sinusoidal positional embeddings similar to what's presented in Attention Is All You Need. """2D sinusoidal positional embeddings similar to what's presented in Attention Is All You Need.
The variation is that the position indices are normalized in [0, 2π] (not quite: the lower bound is 1/H The variation is that the position indices are normalized in [0, 2π] (not quite: the lower bound is 1/H
@@ -561,7 +561,7 @@ class _SinusoidalPositionEmbedding2D(nn.Module):
return pos_embed return pos_embed
def _get_activation_fn(activation: str) -> Callable: def get_activation_fn(activation: str) -> Callable:
"""Return an activation function given a string.""" """Return an activation function given a string."""
if activation == "relu": if activation == "relu":
return F.relu return F.relu

View File

@@ -63,14 +63,14 @@ class DiffusionPolicy(nn.Module):
# queues are populated during rollout of the policy, they contain the n latest observations and actions # queues are populated during rollout of the policy, they contain the n latest observations and actions
self._queues = None self._queues = None
self.diffusion = _DiffusionUnetImagePolicy(cfg) self.diffusion = DiffusionModel(cfg)
# TODO(alexander-soare): This should probably be managed outside of the policy class. # TODO(alexander-soare): This should probably be managed outside of the policy class.
self.ema_diffusion = None self.ema_diffusion = None
self.ema = None self.ema = None
if self.cfg.use_ema: if self.cfg.use_ema:
self.ema_diffusion = copy.deepcopy(self.diffusion) self.ema_diffusion = copy.deepcopy(self.diffusion)
self.ema = _EMA(cfg, model=self.ema_diffusion) self.ema = DiffusionEMA(cfg, model=self.ema_diffusion)
def reset(self): def reset(self):
""" """
@@ -152,13 +152,13 @@ class DiffusionPolicy(nn.Module):
assert len(unexpected_keys) == 0 assert len(unexpected_keys) == 0
class _DiffusionUnetImagePolicy(nn.Module): class DiffusionModel(nn.Module):
def __init__(self, cfg: DiffusionConfig): def __init__(self, cfg: DiffusionConfig):
super().__init__() super().__init__()
self.cfg = cfg self.cfg = cfg
self.rgb_encoder = _RgbEncoder(cfg) self.rgb_encoder = DiffusionRgbEncoder(cfg)
self.unet = _ConditionalUnet1D( self.unet = DiffusionConditionalUnet1d(
cfg, cfg,
global_cond_dim=(cfg.output_shapes["action"][0] + self.rgb_encoder.feature_dim) * cfg.n_obs_steps, global_cond_dim=(cfg.output_shapes["action"][0] + self.rgb_encoder.feature_dim) * cfg.n_obs_steps,
) )
@@ -300,7 +300,7 @@ class _DiffusionUnetImagePolicy(nn.Module):
return loss.mean() return loss.mean()
class _RgbEncoder(nn.Module): class DiffusionRgbEncoder(nn.Module):
"""Encodes an RGB image into a 1D feature vector. """Encodes an RGB image into a 1D feature vector.
Includes the ability to normalize and crop the image first. Includes the ability to normalize and crop the image first.
@@ -403,7 +403,7 @@ def _replace_submodules(
return root_module return root_module
class _SinusoidalPosEmb(nn.Module): class DiffusionSinusoidalPosEmb(nn.Module):
"""1D sinusoidal positional embeddings as in Attention is All You Need.""" """1D sinusoidal positional embeddings as in Attention is All You Need."""
def __init__(self, dim: int): def __init__(self, dim: int):
@@ -420,7 +420,7 @@ class _SinusoidalPosEmb(nn.Module):
return emb return emb
class _Conv1dBlock(nn.Module): class DiffusionConv1dBlock(nn.Module):
"""Conv1d --> GroupNorm --> Mish""" """Conv1d --> GroupNorm --> Mish"""
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
@@ -436,7 +436,7 @@ class _Conv1dBlock(nn.Module):
return self.block(x) return self.block(x)
class _ConditionalUnet1D(nn.Module): class DiffusionConditionalUnet1d(nn.Module):
"""A 1D convolutional UNet with FiLM modulation for conditioning. """A 1D convolutional UNet with FiLM modulation for conditioning.
Note: this removes local conditioning as compared to the original diffusion policy code. Note: this removes local conditioning as compared to the original diffusion policy code.
@@ -449,7 +449,7 @@ class _ConditionalUnet1D(nn.Module):
# Encoder for the diffusion timestep. # Encoder for the diffusion timestep.
self.diffusion_step_encoder = nn.Sequential( self.diffusion_step_encoder = nn.Sequential(
_SinusoidalPosEmb(cfg.diffusion_step_embed_dim), DiffusionSinusoidalPosEmb(cfg.diffusion_step_embed_dim),
nn.Linear(cfg.diffusion_step_embed_dim, cfg.diffusion_step_embed_dim * 4), nn.Linear(cfg.diffusion_step_embed_dim, cfg.diffusion_step_embed_dim * 4),
nn.Mish(), nn.Mish(),
nn.Linear(cfg.diffusion_step_embed_dim * 4, cfg.diffusion_step_embed_dim), nn.Linear(cfg.diffusion_step_embed_dim * 4, cfg.diffusion_step_embed_dim),
@@ -477,8 +477,8 @@ class _ConditionalUnet1D(nn.Module):
self.down_modules.append( self.down_modules.append(
nn.ModuleList( nn.ModuleList(
[ [
_ConditionalResidualBlock1D(dim_in, dim_out, **common_res_block_kwargs), DiffusionConditionalResidualBlock1d(dim_in, dim_out, **common_res_block_kwargs),
_ConditionalResidualBlock1D(dim_out, dim_out, **common_res_block_kwargs), DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
# Downsample as long as it is not the last block. # Downsample as long as it is not the last block.
nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(), nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(),
] ]
@@ -488,8 +488,12 @@ class _ConditionalUnet1D(nn.Module):
# Processing in the middle of the auto-encoder. # Processing in the middle of the auto-encoder.
self.mid_modules = nn.ModuleList( self.mid_modules = nn.ModuleList(
[ [
_ConditionalResidualBlock1D(cfg.down_dims[-1], cfg.down_dims[-1], **common_res_block_kwargs), DiffusionConditionalResidualBlock1d(
_ConditionalResidualBlock1D(cfg.down_dims[-1], cfg.down_dims[-1], **common_res_block_kwargs), cfg.down_dims[-1], cfg.down_dims[-1], **common_res_block_kwargs
),
DiffusionConditionalResidualBlock1d(
cfg.down_dims[-1], cfg.down_dims[-1], **common_res_block_kwargs
),
] ]
) )
@@ -501,8 +505,8 @@ class _ConditionalUnet1D(nn.Module):
nn.ModuleList( nn.ModuleList(
[ [
# dim_in * 2, because it takes the encoder's skip connection as well # dim_in * 2, because it takes the encoder's skip connection as well
_ConditionalResidualBlock1D(dim_in * 2, dim_out, **common_res_block_kwargs), DiffusionConditionalResidualBlock1d(dim_in * 2, dim_out, **common_res_block_kwargs),
_ConditionalResidualBlock1D(dim_out, dim_out, **common_res_block_kwargs), DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs),
# Upsample as long as it is not the last block. # Upsample as long as it is not the last block.
nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(), nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
] ]
@@ -510,7 +514,7 @@ class _ConditionalUnet1D(nn.Module):
) )
self.final_conv = nn.Sequential( self.final_conv = nn.Sequential(
_Conv1dBlock(cfg.down_dims[0], cfg.down_dims[0], kernel_size=cfg.kernel_size), DiffusionConv1dBlock(cfg.down_dims[0], cfg.down_dims[0], kernel_size=cfg.kernel_size),
nn.Conv1d(cfg.down_dims[0], cfg.output_shapes["action"][0], 1), nn.Conv1d(cfg.down_dims[0], cfg.output_shapes["action"][0], 1),
) )
@@ -559,7 +563,7 @@ class _ConditionalUnet1D(nn.Module):
return x return x
class _ConditionalResidualBlock1D(nn.Module): class DiffusionConditionalResidualBlock1d(nn.Module):
"""ResNet style 1D convolutional block with FiLM modulation for conditioning.""" """ResNet style 1D convolutional block with FiLM modulation for conditioning."""
def __init__( def __init__(
@@ -578,13 +582,13 @@ class _ConditionalResidualBlock1D(nn.Module):
self.use_film_scale_modulation = use_film_scale_modulation self.use_film_scale_modulation = use_film_scale_modulation
self.out_channels = out_channels self.out_channels = out_channels
self.conv1 = _Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups) self.conv1 = DiffusionConv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
# FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale. # FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale.
cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels
self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels)) self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))
self.conv2 = _Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups) self.conv2 = DiffusionConv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
# A final convolution for dimension matching the residual (if needed). # A final convolution for dimension matching the residual (if needed).
self.residual_conv = ( self.residual_conv = (
@@ -617,7 +621,7 @@ class _ConditionalResidualBlock1D(nn.Module):
return out return out
class _EMA: class DiffusionEMA:
""" """
Exponential Moving Average of model weights Exponential Moving Average of model weights
""" """

View File

@@ -38,11 +38,11 @@ def make_policy(hydra_cfg: DictConfig, dataset_stats=None):
policy = DiffusionPolicy(policy_cfg, hydra_cfg.training.offline_steps, dataset_stats) policy = DiffusionPolicy(policy_cfg, hydra_cfg.training.offline_steps, dataset_stats)
policy.to(get_safe_torch_device(hydra_cfg.device)) policy.to(get_safe_torch_device(hydra_cfg.device))
elif hydra_cfg.policy.name == "act": elif hydra_cfg.policy.name == "act":
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy from lerobot.common.policies.act.modeling_act import ACTPolicy
policy_cfg = _policy_cfg_from_hydra_cfg(ActionChunkingTransformerConfig, hydra_cfg) policy_cfg = _policy_cfg_from_hydra_cfg(ACTConfig, hydra_cfg)
policy = ActionChunkingTransformerPolicy(policy_cfg, dataset_stats) policy = ACTPolicy(policy_cfg, dataset_stats)
policy.to(get_safe_torch_device(hydra_cfg.device)) policy.to(get_safe_torch_device(hydra_cfg.device))
else: else:
raise ValueError(hydra_cfg.policy.name) raise ValueError(hydra_cfg.policy.name)

View File

@@ -4,7 +4,7 @@ import gymnasium as gym
import pytest import pytest
import lerobot import lerobot
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy from lerobot.common.policies.act.modeling_act import ACTPolicy
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
from tests.utils import require_env from tests.utils import require_env
@@ -30,7 +30,7 @@ def test_available_policies():
consistent with those listed in `lerobot/__init__.py`. consistent with those listed in `lerobot/__init__.py`.
""" """
policy_classes = [ policy_classes = [
ActionChunkingTransformerPolicy, ACTPolicy,
DiffusionPolicy, DiffusionPolicy,
TDMPCPolicy, TDMPCPolicy,
] ]

View File

@@ -5,7 +5,7 @@ from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import cycle from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import postprocess_action, preprocess_observation from lerobot.common.envs.utils import postprocess_action, preprocess_observation
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy from lerobot.common.policies.act.modeling_act import ACTPolicy
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.common.policies.factory import make_policy from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.normalize import Normalize, Unnormalize
@@ -115,7 +115,7 @@ def test_policy(env_name, policy_name, extra_overrides):
new_policy.load_state_dict(policy.state_dict()) new_policy.load_state_dict(policy.state_dict())
@pytest.mark.parametrize("policy_cls", [DiffusionPolicy, ActionChunkingTransformerPolicy]) @pytest.mark.parametrize("policy_cls", [DiffusionPolicy, ACTPolicy])
def test_policy_defaults(policy_cls): def test_policy_defaults(policy_cls):
kwargs = {} kwargs = {}
# TODO(alexander-soare): Remove this kwargs hack when we move the scheduler out of DP. # TODO(alexander-soare): Remove this kwargs hack when we move the scheduler out of DP.