Simplify configs (#550)

Co-authored-by: Remi <remi.cadene@huggingface.co>
Co-authored-by: HUANG TZU-CHUN <137322177+tc-huang@users.noreply.github.com>
This commit is contained in:
Simon Alibert
2025-01-31 13:57:37 +01:00
committed by GitHub
parent 1ee1acf8ad
commit 3c0a209f9f
119 changed files with 5761 additions and 5466 deletions

View File

@@ -15,9 +15,14 @@
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.common.optim.optimizers import AdamWConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
@PreTrainedConfig.register_subclass("act")
@dataclass
class ACTConfig:
class ACTConfig(PreTrainedConfig):
"""Configuration class for the Action Chunking Transformers policy.
Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
@@ -90,28 +95,11 @@ class ACTConfig:
chunk_size: int = 100
n_action_steps: int = 100
input_shapes: dict[str, list[int]] = field(
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"observation.images.top": [3, 480, 640],
"observation.state": [14],
}
)
output_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"action": [14],
}
)
# Normalization / Unnormalization
input_normalization_modes: dict[str, str] = field(
default_factory=lambda: {
"observation.images.top": "mean_std",
"observation.state": "mean_std",
}
)
output_normalization_modes: dict[str, str] = field(
default_factory=lambda: {
"action": "mean_std",
"VISUAL": NormalizationMode.MEAN_STD,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MEAN_STD,
}
)
@@ -144,7 +132,14 @@ class ACTConfig:
dropout: float = 0.1
kl_weight: float = 10.0
# Training preset
optimizer_lr: float = 1e-5
optimizer_weight_decay: float = 1e-4
optimizer_lr_backbone: float = 1e-5
def __post_init__(self):
super().__post_init__()
"""Input validation (not exhaustive)."""
if not self.vision_backbone.startswith("resnet"):
raise ValueError(
@@ -164,8 +159,28 @@ class ACTConfig:
raise ValueError(
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
)
if (
not any(k.startswith("observation.image") for k in self.input_shapes)
and "observation.environment_state" not in self.input_shapes
):
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
weight_decay=self.optimizer_weight_decay,
)
def get_scheduler_preset(self) -> None:
return None
def validate_features(self) -> None:
if not self.image_features and not self.env_state_feature:
raise ValueError("You must provide at least one image or the environment state among the inputs.")
@property
def observation_delta_indices(self) -> None:
return None
@property
def action_delta_indices(self) -> list:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None