forked from tangger/lerobot
Refactor train, eval_policy, logger, Add diffusion.yaml (WIP)
This commit is contained in:
@@ -6,12 +6,19 @@ from .utils import init_config
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"env_name",
|
||||
"env_name,policy_name",
|
||||
[
|
||||
"simxarm",
|
||||
"pusht",
|
||||
("simxarm", "tdmpc"),
|
||||
("pusht", "tdmpc"),
|
||||
("simxarm", "diffusion"),
|
||||
("pusht", "diffusion"),
|
||||
],
|
||||
)
|
||||
def test_factory(env_name):
|
||||
cfg = init_config(overrides=[f"env={env_name}"])
|
||||
def test_factory(env_name, policy_name):
|
||||
cfg = init_config(
|
||||
overrides=[
|
||||
f"env={env_name}",
|
||||
f"policy={policy_name}",
|
||||
]
|
||||
)
|
||||
policy = make_policy(cfg)
|
||||
|
||||
Reference in New Issue
Block a user