Files
lerobot_piper/lerobot/configs/policies.py
Simon Alibert 3c0a209f9f Simplify configs (#550)
Co-authored-by: Remi <remi.cadene@huggingface.co>
Co-authored-by: HUANG TZU-CHUN <137322177+tc-huang@users.noreply.github.com>
2025-01-31 13:57:37 +01:00

146 lines
5.2 KiB
Python

import abc
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Type, TypeVar
import draccus
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import CONFIG_NAME
from huggingface_hub.errors import HfHubHTTPError
from lerobot.common.optim.optimizers import OptimizerConfig
from lerobot.common.optim.schedulers import LRSchedulerConfig
from lerobot.common.utils.hub import HubMixin
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
# Generic variable that is either PreTrainedConfig or a subclass thereof
T = TypeVar("T", bound="PreTrainedConfig")
@dataclass
class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
"""
Base configuration class for policy models.
Args:
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
current step and additional steps going back).
input_shapes: A dictionary defining the shapes of the input data for the policy.
output_shapes: A dictionary defining the shapes of the output data for the policy.
input_normalization_modes: A dictionary with key representing the modality and the value specifies the
normalization mode to apply.
output_normalization_modes: Similar dictionary as `input_normalization_modes`, but to unnormalize to
the original scale.
"""
n_obs_steps: int = 1
normalization_mapping: dict[str, NormalizationMode] = field(default_factory=dict)
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
def __post_init__(self):
self.pretrained_path = None
@property
def type(self) -> str:
return self.get_choice_name(self.__class__)
@abc.abstractproperty
def observation_delta_indices(self) -> list | None:
raise NotImplementedError
@abc.abstractproperty
def action_delta_indices(self) -> list | None:
raise NotImplementedError
@abc.abstractproperty
def reward_delta_indices(self) -> list | None:
raise NotImplementedError
@abc.abstractmethod
def get_optimizer_preset(self) -> OptimizerConfig:
raise NotImplementedError
@abc.abstractmethod
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
raise NotImplementedError
@abc.abstractmethod
def validate_features(self) -> None:
raise NotImplementedError
@property
def robot_state_feature(self) -> PolicyFeature | None:
for _, ft in self.input_features.items():
if ft.type is FeatureType.STATE:
return ft
return None
@property
def env_state_feature(self) -> PolicyFeature | None:
for _, ft in self.input_features.items():
if ft.type is FeatureType.ENV:
return ft
return None
@property
def image_features(self) -> dict[str, PolicyFeature]:
return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.VISUAL}
@property
def action_feature(self) -> PolicyFeature | None:
for _, ft in self.output_features.items():
if ft.type is FeatureType.ACTION:
return ft
return None
def _save_pretrained(self, save_directory: Path) -> None:
with open(save_directory / CONFIG_NAME, "w") as f, draccus.config_type("json"):
draccus.dump(self, f, indent=4)
@classmethod
def from_pretrained(
cls: Type[T],
pretrained_name_or_path: str | Path,
*,
force_download: bool = False,
resume_download: bool = None,
proxies: dict | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
**policy_kwargs,
) -> T:
model_id = str(pretrained_name_or_path)
config_file: str | None = None
if Path(model_id).is_dir():
if CONFIG_NAME in os.listdir(model_id):
config_file = os.path.join(model_id, CONFIG_NAME)
else:
print(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
else:
try:
config_file = hf_hub_download(
repo_id=model_id,
filename=CONFIG_NAME,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
token=token,
local_files_only=local_files_only,
)
except HfHubHTTPError as e:
raise FileNotFoundError(
f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
) from e
# HACK: this is very ugly, ideally we'd like to be able to do that natively with draccus
# something like --policy.path (in addition to --policy.type)
cli_overrides = policy_kwargs.pop("cli_overrides", [])
return draccus.parse(cls, config_file, args=cli_overrides)