modified tests dirs

This commit is contained in:
Michel Aractingi
2024-09-02 08:04:56 +00:00
parent bbce0eaeaf
commit 3034272229
2 changed files with 2 additions and 4 deletions

View File

@@ -22,7 +22,6 @@ from safetensors.torch import save_file
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils.utils import init_hydra_config, set_global_seed
from lerobot.scripts.train import make_optimizer_and_scheduler
from tests.utils import DEFAULT_CONFIG_PATH
@@ -40,7 +39,7 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
dataset = make_dataset(cfg)
policy = make_policy(cfg, dataset_stats=dataset.stats)
policy.train()
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
optimizer, _ = policy.make_optimizer_and_scheduler(**cfg.training)
dataloader = torch.utils.data.DataLoader(
dataset,

View File

@@ -37,7 +37,6 @@ from lerobot.common.policies.factory import (
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import init_hydra_config, seeded_context
from lerobot.scripts.train import make_optimizer_and_scheduler
from tests.scripts.save_policy_to_safetensors import get_policy_stats
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel
@@ -214,7 +213,7 @@ def test_act_backbone_lr():
dataset = make_dataset(cfg)
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
optimizer, _ = policy.make_optimizer_and_scheduler(**cfg.training)
assert len(optimizer.param_groups) == 2
assert optimizer.param_groups[0]["lr"] == cfg.training.lr
assert optimizer.param_groups[1]["lr"] == cfg.training.lr_backbone