diff --git a/tests/scripts/save_policy_to_safetensors.py b/tests/scripts/save_policy_to_safetensors.py
index 5236b7ae..7287ed73 100644
--- a/tests/scripts/save_policy_to_safetensors.py
+++ b/tests/scripts/save_policy_to_safetensors.py
@@ -22,7 +22,6 @@ from safetensors.torch import save_file
 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.utils.utils import init_hydra_config, set_global_seed
-from lerobot.scripts.train import make_optimizer_and_scheduler
 from tests.utils import DEFAULT_CONFIG_PATH
 
 
@@ -40,7 +39,7 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
     dataset = make_dataset(cfg)
     policy = make_policy(cfg, dataset_stats=dataset.stats)
     policy.train()
-    optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
+    optimizer, _ = policy.make_optimizer_and_scheduler(**cfg.training)
 
     dataloader = torch.utils.data.DataLoader(
         dataset,
diff --git a/tests/test_policies.py b/tests/test_policies.py
index d90f0071..69261661 100644
--- a/tests/test_policies.py
+++ b/tests/test_policies.py
@@ -37,7 +37,6 @@ from lerobot.common.policies.factory import (
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.policy_protocol import Policy
 from lerobot.common.utils.utils import init_hydra_config, seeded_context
-from lerobot.scripts.train import make_optimizer_and_scheduler
 from tests.scripts.save_policy_to_safetensors import get_policy_stats
 from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel
 
@@ -214,7 +213,7 @@ def test_act_backbone_lr():
     dataset = make_dataset(cfg)
     policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
 
-    optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
+    optimizer, _ = policy.make_optimizer_and_scheduler(**cfg.training)
     assert len(optimizer.param_groups) == 2
     assert optimizer.param_groups[0]["lr"] == cfg.training.lr
     assert optimizer.param_groups[1]["lr"] == cfg.training.lr_backbone