modified tests dirs
This commit is contained in:
@@ -22,7 +22,6 @@ from safetensors.torch import save_file
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.utils.utils import init_hydra_config, set_global_seed
|
||||
from lerobot.scripts.train import make_optimizer_and_scheduler
|
||||
from tests.utils import DEFAULT_CONFIG_PATH
|
||||
|
||||
|
||||
@@ -40,7 +39,7 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
|
||||
dataset = make_dataset(cfg)
|
||||
policy = make_policy(cfg, dataset_stats=dataset.stats)
|
||||
policy.train()
|
||||
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
|
||||
optimizer, _ = policy.make_optimizer_and_scheduler(**cfg.training)
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
|
||||
@@ -37,7 +37,6 @@ from lerobot.common.policies.factory import (
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.policy_protocol import Policy
|
||||
from lerobot.common.utils.utils import init_hydra_config, seeded_context
|
||||
from lerobot.scripts.train import make_optimizer_and_scheduler
|
||||
from tests.scripts.save_policy_to_safetensors import get_policy_stats
|
||||
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel
|
||||
|
||||
@@ -214,7 +213,7 @@ def test_act_backbone_lr():
|
||||
|
||||
dataset = make_dataset(cfg)
|
||||
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
|
||||
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
|
||||
optimizer, _ = policy.make_optimizer_and_scheduler(**cfg.training)
|
||||
assert len(optimizer.param_groups) == 2
|
||||
assert optimizer.param_groups[0]["lr"] == cfg.training.lr
|
||||
assert optimizer.param_groups[1]["lr"] == cfg.training.lr_backbone
|
||||
|
||||
Reference in New Issue
Block a user