Compare commits

...

1 Commit

Author              SHA1        Message                    Date
Quentin Gallouédec  783a40c9d4  pretrained config for act  2024-04-25 16:06:57 +02:00

@@ -1,8 +1,7 @@
-from dataclasses import dataclass, field
+from transformers.configuration_utils import PretrainedConfig
 
 
-@dataclass
-class ActionChunkingTransformerConfig:
+class ActionChunkingTransformerConfig(PretrainedConfig):
     """Configuration class for the Action Chunking Transformers policy.
 
     Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
@@ -55,37 +54,41 @@ class ActionChunkingTransformerConfig:
         dropout: Dropout to use in the transformer layers (see code for details).
         kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
             is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
-    """
+
+    Example:
+
+    ```python
+    >>> from lerobot import ActionChunkingTransformerConfig
+
+    >>> # Initializing an ACT style configuration
+    >>> configuration = ActionChunkingTransformerConfig()
+
+    >>> # Initializing a model (with random weights) from the ACT style configuration
+    >>> model = ActionChunkingTransformerPolicy(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
 
     # Input / output structure.
     n_obs_steps: int = 1
     chunk_size: int = 100
     n_action_steps: int = 100
-    input_shapes: dict[str, list[str]] = field(
-        default_factory=lambda: {
-            "observation.images.top": [3, 480, 640],
-            "observation.state": [14],
-        }
-    )
-    output_shapes: dict[str, list[str]] = field(
-        default_factory=lambda: {
-            "action": [14],
-        }
-    )
+    input_shapes: dict[str, list[str]] = {
+        "observation.images.top": [3, 480, 640],
+        "observation.state": [14],
+    }
+
+    output_shapes: dict[str, list[str]] = {"action": [14]}
 
     # Normalization / Unnormalization
-    normalize_input_modes: dict[str, str] = field(
-        default_factory=lambda: {
-            "observation.image": "mean_std",
-            "observation.state": "mean_std",
-        }
-    )
-    unnormalize_output_modes: dict[str, str] = field(
-        default_factory=lambda: {
-            "action": "mean_std",
-        }
-    )
+    normalize_input_modes: dict[str, str] = {
+        "observation.image": "mean_std",
+        "observation.state": "mean_std",
+    }
+
+    unnormalize_output_modes: dict[str, str] = {"action": "mean_std"}
 
     # Architecture.
     # Vision backbone.
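The practical effect of subclassing `PretrainedConfig` (and presumably the motivation for the change) is that the ACT config inherits the standard `config.json` round-tripping from `transformers`. A minimal sketch of what that enables, assuming the import path used in the new docstring, that keyword overrides flow through `PretrainedConfig.__init__`'s `**kwargs` handling, and a hypothetical `act_config/` directory name:

```python
from lerobot import ActionChunkingTransformerConfig

# Override a default via keyword arguments; PretrainedConfig stores
# unrecognized kwargs as attributes on the instance (an assumption here,
# since the diff does not show a custom __init__).
config = ActionChunkingTransformerConfig(chunk_size=50)

# Inherited from PretrainedConfig: serialize to / restore from config.json.
config.save_pretrained("act_config")  # creates act_config/config.json
restored = ActionChunkingTransformerConfig.from_pretrained("act_config")
assert restored.chunk_size == 50
```

Note that dropping `@dataclass` is what makes the plain dict defaults in the class body legal without `field(default_factory=...)`, at the cost that they become class-level attributes shared across instances; `PretrainedConfig`-style configs typically sidestep this by assigning per-instance values in `__init__`.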