Refactor configs to have env in seperate yaml + Fix training
This commit is contained in:
17
tests/test_datasets.py
Normal file
17
tests/test_datasets.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import pytest
|
||||
|
||||
from lerobot.common.datasets.factory import make_offline_buffer
|
||||
|
||||
from .utils import init_config
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"env_name",
|
||||
[
|
||||
"simxarm",
|
||||
"pusht",
|
||||
],
|
||||
)
|
||||
def test_factory(env_name):
|
||||
cfg = init_config(overrides=[f"env={env_name}"])
|
||||
offline_buffer = make_offline_buffer(cfg)
|
||||
@@ -78,13 +78,13 @@ def test_pusht(from_pixels, pixels_only):
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"config_name",
|
||||
"env_name",
|
||||
[
|
||||
"default",
|
||||
"simxarm",
|
||||
"pusht",
|
||||
],
|
||||
)
|
||||
def test_factory(config_name):
|
||||
cfg = init_config(config_name)
|
||||
def test_factory(env_name):
|
||||
cfg = init_config(overrides=[f"env={env_name}"])
|
||||
env = make_env(cfg)
|
||||
check_env_specs(env)
|
||||
|
||||
@@ -6,12 +6,12 @@ from .utils import init_config
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"config_name",
|
||||
"env_name",
|
||||
[
|
||||
"default",
|
||||
"simxarm",
|
||||
"pusht",
|
||||
],
|
||||
)
|
||||
def test_factory(config_name):
|
||||
cfg = init_config(config_name)
|
||||
def test_factory(env_name):
|
||||
cfg = init_config(overrides=[f"env={env_name}"])
|
||||
policy = make_policy(cfg)
|
||||
|
||||
@@ -4,8 +4,8 @@ from hydra import compose, initialize
|
||||
CONFIG_PATH = "../lerobot/configs"
|
||||
|
||||
|
||||
def init_config(config_name):
|
||||
def init_config(config_name="default", overrides=None):
|
||||
hydra.core.global_hydra.GlobalHydra.instance().clear()
|
||||
initialize(config_path=CONFIG_PATH)
|
||||
cfg = compose(config_name=config_name)
|
||||
cfg = compose(config_name=config_name, overrides=overrides)
|
||||
return cfg
|
||||
|
||||
Reference in New Issue
Block a user