Move make_policy_config

This commit is contained in:
Simon Alibert
2025-03-08 12:42:43 +01:00
parent 00698305a3
commit 57e0bdb491
5 changed files with 11 additions and 32 deletions

View File

@@ -54,31 +54,6 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
raise NotImplementedError(f"Policy with name {name} is not implemented.")
def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
if policy_type == "act":
from .act import ACTConfig
return ACTConfig(**kwargs)
elif policy_type == "diffusion":
from .diffusion import DiffusionConfig
return DiffusionConfig(**kwargs)
elif policy_type == "pi0":
from .pi0 import PI0Config
return PI0Config(**kwargs)
elif policy_type == "tdmpc":
from .tdmpc import TDMPCConfig
return TDMPCConfig(**kwargs)
elif policy_type == "vqbet":
from .vqbet import VQBeTConfig
return VQBeTConfig(**kwargs)
else:
raise ValueError(f"Policy type '{policy_type}' is not available.")
def make_policy(
cfg: PreTrainedConfig,
ds_meta: LeRobotDatasetMetadata | None = None,

View File

@@ -22,10 +22,10 @@ from safetensors.torch import save_file
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.optim.factory import make_optimizer_and_scheduler
from lerobot.common.policies import make_policy
from lerobot.common.policies.factory import make_policy_config
from lerobot.common.utils.random_utils import set_seed
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from tests.utils import make_policy_config
def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):

View File

@@ -40,12 +40,11 @@ from lerobot.common.datasets.utils import (
unflatten_dict,
)
from lerobot.common.envs.factory import make_env_config
from lerobot.common.policies.factory import make_policy_config
from lerobot.common.robot_devices.robots.utils import make_robot
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
from tests.utils import require_x86_64_kernel
from tests.utils import make_policy_config, require_x86_64_kernel
@pytest.fixture

View File

@@ -28,11 +28,11 @@ from lerobot.common.datasets.utils import cycle, dataset_to_policy_features
from lerobot.common.envs.factory import make_env, make_env_config
from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.optim.factory import make_optimizer_and_scheduler
from lerobot.common.policies.act import ACTConfig
from lerobot.common.policies.act.modeling_act import ACTTemporalEnsembler
from lerobot.common.policies.factory import (
get_policy_class,
make_policy,
make_policy_config,
)
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
@@ -41,7 +41,7 @@ from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from tests.scripts.save_policy_to_safetensors import get_policy_stats
from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel
from tests.utils import DEVICE, make_policy_config, require_cpu, require_env, require_x86_64_kernel
@pytest.fixture
@@ -208,11 +208,10 @@ def test_act_backbone_lr():
"""
Test that the ACT policy can be instantiated with a different learning rate for the backbone.
"""
cfg = TrainPipelineConfig(
# TODO(rcadene, aliberts): remove dataset download
dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]),
policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001),
policy=ACTConfig(optimizer_lr=0.01, optimizer_lr_backbone=0.001),
)
cfg.validate() # Needed for auto-setting some parameters

View File

@@ -23,6 +23,7 @@ import pytest
import torch
from lerobot import available_cameras, available_motors, available_robots
from lerobot.common.policies import PreTrainedConfig
from lerobot.common.robot_devices.cameras.utils import Camera
from lerobot.common.robot_devices.cameras.utils import make_camera as make_camera_device
from lerobot.common.robot_devices.motors.utils import MotorsBus
@@ -329,3 +330,8 @@ def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus:
else:
raise ValueError(f"The motor type '{motor_type}' is not valid.")
def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
policy_cfg_cls = PreTrainedConfig.get_choice_class(policy_type)
return policy_cfg_cls(**kwargs)