Move make_policy_config
This commit is contained in:
@@ -54,31 +54,6 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
|
||||
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
||||
|
||||
|
||||
def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
if policy_type == "act":
|
||||
from .act import ACTConfig
|
||||
|
||||
return ACTConfig(**kwargs)
|
||||
elif policy_type == "diffusion":
|
||||
from .diffusion import DiffusionConfig
|
||||
|
||||
return DiffusionConfig(**kwargs)
|
||||
elif policy_type == "pi0":
|
||||
from .pi0 import PI0Config
|
||||
|
||||
return PI0Config(**kwargs)
|
||||
elif policy_type == "tdmpc":
|
||||
from .tdmpc import TDMPCConfig
|
||||
|
||||
return TDMPCConfig(**kwargs)
|
||||
elif policy_type == "vqbet":
|
||||
from .vqbet import VQBeTConfig
|
||||
|
||||
return VQBeTConfig(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
||||
|
||||
|
||||
def make_policy(
|
||||
cfg: PreTrainedConfig,
|
||||
ds_meta: LeRobotDatasetMetadata | None = None,
|
||||
|
||||
@@ -22,10 +22,10 @@ from safetensors.torch import save_file
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.common.policies import make_policy
|
||||
from lerobot.common.policies.factory import make_policy_config
|
||||
from lerobot.common.utils.random_utils import set_seed
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from tests.utils import make_policy_config
|
||||
|
||||
|
||||
def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
||||
|
||||
@@ -40,12 +40,11 @@ from lerobot.common.datasets.utils import (
|
||||
unflatten_dict,
|
||||
)
|
||||
from lerobot.common.envs.factory import make_env_config
|
||||
from lerobot.common.policies.factory import make_policy_config
|
||||
from lerobot.common.robot_devices.robots.utils import make_robot
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
||||
from tests.utils import require_x86_64_kernel
|
||||
from tests.utils import make_policy_config, require_x86_64_kernel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -28,11 +28,11 @@ from lerobot.common.datasets.utils import cycle, dataset_to_policy_features
|
||||
from lerobot.common.envs.factory import make_env, make_env_config
|
||||
from lerobot.common.envs.utils import preprocess_observation
|
||||
from lerobot.common.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.common.policies.act import ACTConfig
|
||||
from lerobot.common.policies.act.modeling_act import ACTTemporalEnsembler
|
||||
from lerobot.common.policies.factory import (
|
||||
get_policy_class,
|
||||
make_policy,
|
||||
make_policy_config,
|
||||
)
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
@@ -41,7 +41,7 @@ from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from tests.scripts.save_policy_to_safetensors import get_policy_stats
|
||||
from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel
|
||||
from tests.utils import DEVICE, make_policy_config, require_cpu, require_env, require_x86_64_kernel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -208,11 +208,10 @@ def test_act_backbone_lr():
|
||||
"""
|
||||
Test that the ACT policy can be instantiated with a different learning rate for the backbone.
|
||||
"""
|
||||
|
||||
cfg = TrainPipelineConfig(
|
||||
# TODO(rcadene, aliberts): remove dataset download
|
||||
dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]),
|
||||
policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001),
|
||||
policy=ACTConfig(optimizer_lr=0.01, optimizer_lr_backbone=0.001),
|
||||
)
|
||||
cfg.validate() # Needed for auto-setting some parameters
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
from lerobot import available_cameras, available_motors, available_robots
|
||||
from lerobot.common.policies import PreTrainedConfig
|
||||
from lerobot.common.robot_devices.cameras.utils import Camera
|
||||
from lerobot.common.robot_devices.cameras.utils import make_camera as make_camera_device
|
||||
from lerobot.common.robot_devices.motors.utils import MotorsBus
|
||||
@@ -329,3 +330,8 @@ def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus:
|
||||
|
||||
else:
|
||||
raise ValueError(f"The motor type '{motor_type}' is not valid.")
|
||||
|
||||
|
||||
def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
policy_cfg_cls = PreTrainedConfig.get_choice_class(policy_type)
|
||||
return policy_cfg_cls(**kwargs)
|
||||
|
||||
Reference in New Issue
Block a user