167 lines
8.8 KiB
Python
167 lines
8.8 KiB
Python
#!/usr/bin/env python
|
||
|
||
# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved.
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
from dataclasses import dataclass, field
|
||
|
||
|
||
@dataclass
|
||
class ACTConfig:
|
||
"""Configuration class for the Action Chunking Transformers policy.
|
||
|
||
Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
|
||
|
||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
||
Those are: `input_shapes` and 'output_shapes`.
|
||
|
||
Notes on the inputs and outputs:
|
||
- At least one key starting with "observation.image is required as an input.
|
||
- If there are multiple keys beginning with "observation.images." they are treated as multiple camera
|
||
views. Right now we only support all images having the same shape.
|
||
- May optionally work without an "observation.state" key for the proprioceptive robot state.
|
||
- "action" is required as an output key.
|
||
|
||
Args:
|
||
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
|
||
current step and additional steps going back).
|
||
chunk_size: The size of the action prediction "chunks" in units of environment steps.
|
||
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
|
||
This should be no greater than the chunk size. For example, if the chunk size size 100, you may
|
||
set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
|
||
environment, and throws the other 50 out.
|
||
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
|
||
the input data name, and the value is a list indicating the dimensions of the corresponding data.
|
||
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
|
||
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
|
||
include batch dimension or temporal dimension.
|
||
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
|
||
the output data name, and the value is a list indicating the dimensions of the corresponding data.
|
||
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
|
||
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
|
||
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
||
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
||
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
|
||
[-1, 1] range.
|
||
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
|
||
original scale. Note that this is also used for normalizing the training targets.
|
||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||
pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone.
|
||
`None` means no pretrained weights.
|
||
replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated
|
||
convolution.
|
||
pre_norm: Whether to use "pre-norm" in the transformer blocks.
|
||
dim_model: The transformer blocks' main hidden dimension.
|
||
n_heads: The number of heads to use in the transformer blocks' multi-head attention.
|
||
dim_feedforward: The dimension to expand the transformer's hidden dimension to in the feed-forward
|
||
layers.
|
||
feedforward_activation: The activation to use in the transformer block's feed-forward layers.
|
||
n_encoder_layers: The number of transformer layers to use for the transformer encoder.
|
||
n_decoder_layers: The number of transformer layers to use for the transformer decoder.
|
||
use_vae: Whether to use a variational objective during training. This introduces another transformer
|
||
which is used as the VAE's encoder (not to be confused with the transformer encoder - see
|
||
documentation in the policy class).
|
||
latent_dim: The VAE's latent dimension.
|
||
n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder.
|
||
temporal_ensemble_momentum: Exponential moving average (EMA) momentum parameter (α) for ensembling
|
||
actions for a given time step over multiple policy invocations. Updates are calculated as:
|
||
x⁻ₙ = αx⁻ₙ₋₁ + (1-α)xₙ. Note that the ACT paper and original ACT code describes a different
|
||
parameter here: they refer to a weighting scheme wᵢ = exp(-m⋅i) and set m = 0.01. With our
|
||
formulation, this is equivalent to α = exp(-0.01) ≈ 0.99. When this parameter is provided, we
|
||
require `n_action_steps == 1` (since we need to query the policy every step anyway).
|
||
dropout: Dropout to use in the transformer layers (see code for details).
|
||
kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
|
||
is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
|
||
"""
|
||
|
||
# Input / output structure.
|
||
n_obs_steps: int = 1
|
||
chunk_size: int = 100
|
||
n_action_steps: int = 100
|
||
|
||
input_shapes: dict[str, list[int]] = field(
|
||
default_factory=lambda: {
|
||
"observation.images.top": [3, 480, 640],
|
||
"observation.state": [14],
|
||
}
|
||
)
|
||
output_shapes: dict[str, list[int]] = field(
|
||
default_factory=lambda: {
|
||
"action": [14],
|
||
}
|
||
)
|
||
|
||
# Normalization / Unnormalization
|
||
input_normalization_modes: dict[str, str] = field(
|
||
default_factory=lambda: {
|
||
"observation.images.top": "mean_std",
|
||
"observation.state": "mean_std",
|
||
}
|
||
)
|
||
output_normalization_modes: dict[str, str] = field(
|
||
default_factory=lambda: {
|
||
"action": "mean_std",
|
||
}
|
||
)
|
||
|
||
# Architecture.
|
||
# Vision backbone.
|
||
vision_backbone: str = "resnet18"
|
||
pretrained_backbone_weights: str | None = "ResNet18_Weights.IMAGENET1K_V1"
|
||
replace_final_stride_with_dilation: int = False
|
||
# Transformer layers.
|
||
pre_norm: bool = False
|
||
dim_model: int = 512
|
||
n_heads: int = 8
|
||
dim_feedforward: int = 3200
|
||
feedforward_activation: str = "relu"
|
||
n_encoder_layers: int = 4
|
||
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
|
||
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
|
||
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
|
||
# As a consequence we also remove the final, unused layer normalization, by default
|
||
n_decoder_layers: int = 1
|
||
decoder_norm: bool = True
|
||
# VAE.
|
||
use_vae: bool = True
|
||
latent_dim: int = 32
|
||
n_vae_encoder_layers: int = 4
|
||
|
||
# Inference.
|
||
temporal_ensemble_momentum: float | None = None
|
||
|
||
# Training and loss computation.
|
||
dropout: float = 0.1
|
||
kl_weight: float = 10.0
|
||
|
||
def __post_init__(self):
|
||
"""Input validation (not exhaustive)."""
|
||
if not self.vision_backbone.startswith("resnet"):
|
||
raise ValueError(
|
||
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||
)
|
||
if self.temporal_ensemble_momentum is not None and self.n_action_steps > 1:
|
||
raise NotImplementedError(
|
||
"`n_action_steps` must be 1 when using temporal ensembling. This is "
|
||
"because the policy needs to be queried every step to compute the ensembled action."
|
||
)
|
||
if self.n_action_steps > self.chunk_size:
|
||
raise ValueError(
|
||
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
|
||
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
|
||
)
|
||
if self.n_obs_steps != 1:
|
||
raise ValueError(
|
||
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
|
||
)
|