Simplify configs (#550)

Co-authored-by: Remi <remi.cadene@huggingface.co>
Co-authored-by: HUANG TZU-CHUN <137322177+tc-huang@users.noreply.github.com>
This commit is contained in:
Simon Alibert
2025-01-31 13:57:37 +01:00
committed by GitHub
parent 1ee1acf8ad
commit 3c0a209f9f
119 changed files with 5761 additions and 5466 deletions

View File

@@ -29,32 +29,27 @@ import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
import torchvision
from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor, nn
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.ops.misc import FrozenBatchNorm2d
from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
class ACTPolicy(
nn.Module,
PyTorchModelHubMixin,
library_name="lerobot",
repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "act"],
):
class ACTPolicy(PreTrainedPolicy):
"""
Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
"""
config_class = ACTConfig
name = "act"
def __init__(
self,
config: ACTConfig | None = None,
config: ACTConfig,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
@@ -64,30 +59,46 @@ class ACTPolicy(
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__()
if config is None:
config = ACTConfig()
self.config: ACTConfig = config
super().__init__(config)
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(
config.input_shapes, config.input_normalization_modes, dataset_stats
)
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
config.output_features, config.normalization_mapping, dataset_stats
)
self.model = ACT(config)
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
if config.temporal_ensemble_coeff is not None:
self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)
self.reset()
def get_optim_params(self) -> dict:
# TODO(aliberts, rcadene): As of now, lr_backbone == lr
# Should we remove this and just `return self.parameters()`?
return [
{
"params": [
p
for n, p in self.named_parameters()
if not n.startswith("model.backbone") and p.requires_grad
]
},
{
"params": [
p
for n, p in self.named_parameters()
if n.startswith("model.backbone") and p.requires_grad
],
"lr": self.config.optimizer_lr_backbone,
},
]
def reset(self):
"""This should be called whenever the environment is reset."""
if self.config.temporal_ensemble_coeff is not None:
@@ -106,9 +117,11 @@ class ACTPolicy(
self.eval()
batch = self.normalize_inputs(batch)
if len(self.expected_image_keys) > 0:
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
batch["observation.images"] = torch.stack(
[batch[key] for key in self.config.image_features], dim=-4
)
# If we are doing temporal ensembling, do online updates where we keep track of the number of actions
# we are ensembling over.
@@ -134,9 +147,11 @@ class ACTPolicy(
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
if len(self.expected_image_keys) > 0:
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
batch["observation.images"] = torch.stack(
[batch[key] for key in self.config.image_features], dim=-4
)
batch = self.normalize_targets(batch)
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
@@ -288,31 +303,30 @@ class ACT(nn.Module):
"""
def __init__(self, config: ACTConfig):
super().__init__()
self.config = config
# BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
self.use_robot_state = "observation.state" in config.input_shapes
self.use_images = any(k.startswith("observation.image") for k in config.input_shapes)
self.use_env_state = "observation.environment_state" in config.input_shapes
super().__init__()
self.config = config
if self.config.use_vae:
self.vae_encoder = ACTEncoder(config, is_vae_encoder=True)
self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
# Projection layer for joint-space configuration to hidden dimension.
if self.use_robot_state:
if self.config.robot_state_feature:
self.vae_encoder_robot_state_input_proj = nn.Linear(
config.input_shapes["observation.state"][0], config.dim_model
self.config.robot_state_feature.shape[0], config.dim_model
)
# Projection layer for action (joint-space target) to hidden dimension.
self.vae_encoder_action_input_proj = nn.Linear(
config.output_shapes["action"][0], config.dim_model
self.config.action_feature.shape[0],
config.dim_model,
)
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2)
# Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
# dimension.
num_input_token_encoder = 1 + config.chunk_size
if self.use_robot_state:
if self.config.robot_state_feature:
num_input_token_encoder += 1
self.register_buffer(
"vae_encoder_pos_enc",
@@ -320,7 +334,7 @@ class ACT(nn.Module):
)
# Backbone for image feature extraction.
if self.use_images:
if self.config.image_features:
backbone_model = getattr(torchvision.models, config.vision_backbone)(
replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation],
weights=config.pretrained_backbone_weights,
@@ -337,27 +351,27 @@ class ACT(nn.Module):
# Transformer encoder input projections. The tokens will be structured like
# [latent, (robot_state), (env_state), (image_feature_map_pixels)].
if self.use_robot_state:
if self.config.robot_state_feature:
self.encoder_robot_state_input_proj = nn.Linear(
config.input_shapes["observation.state"][0], config.dim_model
self.config.robot_state_feature.shape[0], config.dim_model
)
if self.use_env_state:
if self.config.env_state_feature:
self.encoder_env_state_input_proj = nn.Linear(
config.input_shapes["observation.environment_state"][0], config.dim_model
self.config.env_state_feature.shape[0], config.dim_model
)
self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model)
if self.use_images:
if self.config.image_features:
self.encoder_img_feat_input_proj = nn.Conv2d(
backbone_model.fc.in_features, config.dim_model, kernel_size=1
)
# Transformer encoder positional embeddings.
n_1d_tokens = 1 # for the latent
if self.use_robot_state:
if self.config.robot_state_feature:
n_1d_tokens += 1
if self.use_env_state:
if self.config.env_state_feature:
n_1d_tokens += 1
self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model)
if self.use_images:
if self.config.image_features:
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
# Transformer decoder.
@@ -365,7 +379,7 @@ class ACT(nn.Module):
self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model)
# Final action regression head on the output of the transformer's decoder.
self.action_head = nn.Linear(config.dim_model, config.output_shapes["action"][0])
self.action_head = nn.Linear(config.dim_model, self.config.action_feature.shape[0])
self._reset_parameters()
@@ -380,13 +394,13 @@ class ACT(nn.Module):
`batch` should have the following structure:
{
"observation.state" (optional): (B, state_dim) batch of robot states.
[robot_state_feature] (optional): (B, state_dim) batch of robot states.
"observation.images": (B, n_cameras, C, H, W) batch of images.
[image_features]: (B, n_cameras, C, H, W) batch of images.
AND/OR
"observation.environment_state": (B, env_dim) batch of environment states.
[env_state_feature]: (B, env_dim) batch of environment states.
"action" (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions.
[action_feature] (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions.
}
Returns:
@@ -411,12 +425,12 @@ class ACT(nn.Module):
cls_embed = einops.repeat(
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
) # (B, 1, D)
if self.use_robot_state:
if self.config.robot_state_feature:
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
if self.use_robot_state:
if self.config.robot_state_feature:
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
else:
vae_encoder_input = [cls_embed, action_embed]
@@ -430,7 +444,7 @@ class ACT(nn.Module):
# sequence depending whether we use the input states or not (cls and robot state)
# False means not a padding token.
cls_joint_is_pad = torch.full(
(batch_size, 2 if self.use_robot_state else 1),
(batch_size, 2 if self.config.robot_state_feature else 1),
False,
device=batch["observation.state"].device,
)
@@ -463,16 +477,16 @@ class ACT(nn.Module):
encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)]
encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1))
# Robot state token.
if self.use_robot_state:
if self.config.robot_state_feature:
encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
# Environment state token.
if self.use_env_state:
if self.config.env_state_feature:
encoder_in_tokens.append(
self.encoder_env_state_input_proj(batch["observation.environment_state"])
)
# Camera observation features and positional embeddings.
if self.use_images:
if self.config.image_features:
all_cam_features = []
all_cam_pos_embeds = []