From 57e0bdb49103530c4d206af5b2b3b5a66034dd37 Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Sat, 8 Mar 2025 12:42:43 +0100 Subject: [PATCH] Move make_policy_config --- lerobot/common/policies/factory.py | 25 --------------------- tests/scripts/save_policy_to_safetensors.py | 2 +- tests/test_datasets.py | 3 +-- tests/test_policies.py | 7 +++--- tests/utils.py | 6 +++++ 5 files changed, 11 insertions(+), 32 deletions(-) diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 05a8026e3..7faab186b 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -54,31 +54,6 @@ def get_policy_class(name: str) -> PreTrainedPolicy: raise NotImplementedError(f"Policy with name {name} is not implemented.") -def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: - if policy_type == "act": - from .act import ACTConfig - - return ACTConfig(**kwargs) - elif policy_type == "diffusion": - from .diffusion import DiffusionConfig - - return DiffusionConfig(**kwargs) - elif policy_type == "pi0": - from .pi0 import PI0Config - - return PI0Config(**kwargs) - elif policy_type == "tdmpc": - from .tdmpc import TDMPCConfig - - return TDMPCConfig(**kwargs) - elif policy_type == "vqbet": - from .vqbet import VQBeTConfig - - return VQBeTConfig(**kwargs) - else: - raise ValueError(f"Policy type '{policy_type}' is not available.") - - def make_policy( cfg: PreTrainedConfig, ds_meta: LeRobotDatasetMetadata | None = None, diff --git a/tests/scripts/save_policy_to_safetensors.py b/tests/scripts/save_policy_to_safetensors.py index 4f61bbd4e..ed04b5ce2 100644 --- a/tests/scripts/save_policy_to_safetensors.py +++ b/tests/scripts/save_policy_to_safetensors.py @@ -22,10 +22,10 @@ from safetensors.torch import save_file from lerobot.common.datasets.factory import make_dataset from lerobot.common.optim.factory import make_optimizer_and_scheduler from lerobot.common.policies import make_policy -from lerobot.common.policies.factory import make_policy_config from lerobot.common.utils.random_utils import set_seed from lerobot.configs.default import DatasetConfig from lerobot.configs.train import TrainPipelineConfig +from tests.utils import make_policy_config def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict): diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 0deacebab..9b63b19ab 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -40,12 +40,11 @@ from lerobot.common.datasets.utils import ( unflatten_dict, ) from lerobot.common.envs.factory import make_env_config -from lerobot.common.policies.factory import make_policy_config from lerobot.common.robot_devices.robots.utils import make_robot from lerobot.configs.default import DatasetConfig from lerobot.configs.train import TrainPipelineConfig from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID -from tests.utils import require_x86_64_kernel +from tests.utils import make_policy_config, require_x86_64_kernel @pytest.fixture diff --git a/tests/test_policies.py b/tests/test_policies.py index 7df79c1d7..3e7591bf0 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -28,11 +28,11 @@ from lerobot.common.datasets.utils import cycle, dataset_to_policy_features from lerobot.common.envs.factory import make_env, make_env_config from lerobot.common.envs.utils import preprocess_observation from lerobot.common.optim.factory import make_optimizer_and_scheduler +from lerobot.common.policies.act import ACTConfig from lerobot.common.policies.act.modeling_act import ACTTemporalEnsembler from lerobot.common.policies.factory import ( get_policy_class, make_policy, - make_policy_config, ) from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.pretrained import PreTrainedPolicy @@ -41,7 +41,7 @@ from lerobot.configs.default import DatasetConfig from lerobot.configs.train import TrainPipelineConfig from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from tests.scripts.save_policy_to_safetensors import get_policy_stats -from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel +from tests.utils import DEVICE, make_policy_config, require_cpu, require_env, require_x86_64_kernel @pytest.fixture @@ -208,11 +208,10 @@ def test_act_backbone_lr(): """ Test that the ACT policy can be instantiated with a different learning rate for the backbone. """ - cfg = TrainPipelineConfig( # TODO(rcadene, aliberts): remove dataset download dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]), - policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001), + policy=ACTConfig(optimizer_lr=0.01, optimizer_lr_backbone=0.001), ) cfg.validate() # Needed for auto-setting some parameters diff --git a/tests/utils.py b/tests/utils.py index c49b5b9ff..5f01d0913 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -23,6 +23,7 @@ import pytest import torch from lerobot import available_cameras, available_motors, available_robots +from lerobot.common.policies import PreTrainedConfig from lerobot.common.robot_devices.cameras.utils import Camera from lerobot.common.robot_devices.cameras.utils import make_camera as make_camera_device from lerobot.common.robot_devices.motors.utils import MotorsBus @@ -329,3 +330,8 @@ def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus: else: raise ValueError(f"The motor type '{motor_type}' is not valid.") + + +def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: + policy_cfg_cls = PreTrainedConfig.get_choice_class(policy_type) + return policy_cfg_cls(**kwargs)