modified tests dirs
This commit is contained in:
@@ -22,7 +22,6 @@ from safetensors.torch import save_file
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.utils.utils import init_hydra_config, set_global_seed
|
||||
from lerobot.scripts.train import make_optimizer_and_scheduler
|
||||
from tests.utils import DEFAULT_CONFIG_PATH
|
||||
|
||||
|
||||
@@ -40,7 +39,7 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
|
||||
dataset = make_dataset(cfg)
|
||||
policy = make_policy(cfg, dataset_stats=dataset.stats)
|
||||
policy.train()
|
||||
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
|
||||
optimizer, _ = policy.make_optimizer_and_scheduler(**cfg.training)
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
|
||||
Reference in New Issue
Block a user