Move make_policy_config

This commit is contained in:
Simon Alibert
2025-03-08 12:42:43 +01:00
parent 00698305a3
commit 57e0bdb491
5 changed files with 11 additions and 32 deletions

View File

@@ -23,6 +23,7 @@ import pytest
import torch
from lerobot import available_cameras, available_motors, available_robots
from lerobot.common.policies import PreTrainedConfig
from lerobot.common.robot_devices.cameras.utils import Camera
from lerobot.common.robot_devices.cameras.utils import make_camera as make_camera_device
from lerobot.common.robot_devices.motors.utils import MotorsBus
@@ -329,3 +330,8 @@ def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus:
else:
raise ValueError(f"The motor type '{motor_type}' is not valid.")
def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
policy_cfg_cls = PreTrainedConfig.get_choice_class(policy_type)
return policy_cfg_cls(**kwargs)