From 783a40c9d4ed9d59914c5eb9f1107ebcd671f4a3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com>
Date: Thu, 25 Apr 2024 16:06:57 +0200
Subject: [PATCH] pretrained config for act

---
 .../common/policies/act/configuration_act.py | 55 ++++++++++---------
 1 file changed, 29 insertions(+), 26 deletions(-)

diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py
index 82280b2c..c54c6809 100644
--- a/lerobot/common/policies/act/configuration_act.py
+++ b/lerobot/common/policies/act/configuration_act.py
@@ -1,8 +1,7 @@
-from dataclasses import dataclass, field
+from transformers.configuration_utils import PretrainedConfig
 
 
-@dataclass
-class ActionChunkingTransformerConfig:
+class ActionChunkingTransformerConfig(PretrainedConfig):
     """Configuration class for the Action Chunking Transformers policy.
 
     Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
@@ -55,37 +54,41 @@ class ActionChunkingTransformerConfig:
         dropout: Dropout to use in the transformer layers (see code for details).
         kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
             is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
-    """
+
+    Example:
+
+    ```python
+    >>> from lerobot import ActionChunkingTransformerConfig, ActionChunkingTransformerPolicy
+
+    >>> # Initializing an ACT style configuration
+    >>> configuration = ActionChunkingTransformerConfig()
+
+    >>> # Initializing a model (with random weights) from the ACT style configuration
+    >>> model = ActionChunkingTransformerPolicy(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
 
     # Input / output structure.
     n_obs_steps: int = 1
     chunk_size: int = 100
     n_action_steps: int = 100
 
-    input_shapes: dict[str, list[str]] = field(
-        default_factory=lambda: {
-            "observation.images.top": [3, 480, 640],
-            "observation.state": [14],
-        }
-    )
-    output_shapes: dict[str, list[str]] = field(
-        default_factory=lambda: {
-            "action": [14],
-        }
-    )
+    input_shapes: dict[str, list[str]] = {
+        "observation.images.top": [3, 480, 640],
+        "observation.state": [14],
+    }
+
+    output_shapes: dict[str, list[str]] = {"action": [14]}
 
     # Normalization / Unnormalization
-    normalize_input_modes: dict[str, str] = field(
-        default_factory=lambda: {
-            "observation.image": "mean_std",
-            "observation.state": "mean_std",
-        }
-    )
-    unnormalize_output_modes: dict[str, str] = field(
-        default_factory=lambda: {
-            "action": "mean_std",
-        }
-    )
+    normalize_input_modes: dict[str, str] = {
+        "observation.image": "mean_std",
+        "observation.state": "mean_std",
+    }
+
+    unnormalize_output_modes: dict[str, str] = {"action": "mean_std"}
 
     # Architecture.
     # Vision backbone.
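
Review note: with `@dataclass` removed, the dict-valued defaults in this patch become class attributes shared by every instance, so mutating `config.input_shapes` on one config would leak into all others. The usual `PretrainedConfig` idiom in transformers avoids this by taking defaults as `__init__` arguments and forwarding the rest to `super().__init__(**kwargs)`, which is also what lets `save_pretrained`/`from_pretrained` round-trip the fields through `config.json`. A minimal sketch of that idiom for the same fields follows; the `model_type = "act"` identifier is an assumption, not something this PR defines:

```python
from transformers.configuration_utils import PretrainedConfig


class ActionChunkingTransformerConfig(PretrainedConfig):
    """Sketch: defaults live in __init__ rather than as class attributes."""

    model_type = "act"  # assumed identifier; transformers uses this for config registration

    def __init__(
        self,
        n_obs_steps: int = 1,
        chunk_size: int = 100,
        n_action_steps: int = 100,
        input_shapes: dict | None = None,
        output_shapes: dict | None = None,
        normalize_input_modes: dict | None = None,
        unnormalize_output_modes: dict | None = None,
        **kwargs,
    ):
        self.n_obs_steps = n_obs_steps
        self.chunk_size = chunk_size
        self.n_action_steps = n_action_steps
        # Build fresh dicts per instance so one config's mutations cannot leak into another.
        self.input_shapes = input_shapes or {
            "observation.images.top": [3, 480, 640],
            "observation.state": [14],
        }
        self.output_shapes = output_shapes or {"action": [14]}
        self.normalize_input_modes = normalize_input_modes or {
            "observation.image": "mean_std",
            "observation.state": "mean_std",
        }
        self.unnormalize_output_modes = unnormalize_output_modes or {"action": "mean_std"}
        # Forward remaining arguments (serialization bookkeeping, overrides) to PretrainedConfig.
        super().__init__(**kwargs)
```

With this shape, `ActionChunkingTransformerConfig(chunk_size=50).save_pretrained("act")` followed by `ActionChunkingTransformerConfig.from_pretrained("act")` restores the overridden field, since `PretrainedConfig` serializes whatever `__init__` set on `self`.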