forked from tangger/lerobot
@@ -50,7 +50,7 @@ def test_get_policy_and_config_classes(policy_name: str):
|
||||
assert issubclass(config_cls, inspect.signature(policy_cls.__init__).parameters["config"].annotation)
|
||||
|
||||
|
||||
# TODO(aliberts): refactor using lerobot/__init__.py variables
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
@pytest.mark.parametrize(
|
||||
"env_name,policy_name,extra_overrides",
|
||||
[
|
||||
@@ -136,7 +136,7 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||
|
||||
# Check that we can make the policy object.
|
||||
dataset = make_dataset(cfg)
|
||||
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
|
||||
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.meta.stats)
|
||||
# Check that the policy follows the required protocol.
|
||||
assert isinstance(
|
||||
policy, Policy
|
||||
@@ -195,6 +195,7 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||
env.step(action)
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
def test_act_backbone_lr():
|
||||
"""
|
||||
Test that the ACT policy can be instantiated with a different learning rate for the backbone.
|
||||
@@ -213,7 +214,7 @@ def test_act_backbone_lr():
|
||||
assert cfg.training.lr_backbone == 0.001
|
||||
|
||||
dataset = make_dataset(cfg)
|
||||
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
|
||||
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.meta.stats)
|
||||
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
|
||||
assert len(optimizer.param_groups) == 2
|
||||
assert optimizer.param_groups[0]["lr"] == cfg.training.lr
|
||||
@@ -351,6 +352,7 @@ def test_normalize(insert_temporal_dim):
|
||||
unnormalize(output_batch)
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
@pytest.mark.parametrize(
|
||||
"env_name, policy_name, extra_overrides, file_name_extra",
|
||||
[
|
||||
@@ -381,7 +383,7 @@ def test_backward_compatibility(env_name, policy_name, extra_overrides, file_nam
|
||||
include a report on what changed and how that affected the outputs.
|
||||
2. Go to the `if __name__ == "__main__"` block of `tests/scripts/save_policy_to_safetensors.py` and
|
||||
add the policies you want to update the test artifacts for.
|
||||
3. Run `DATA_DIR=tests/data python tests/scripts/save_policy_to_safetensors.py`. The test artifact
|
||||
3. Run `python tests/scripts/save_policy_to_safetensors.py`. The test artifact
|
||||
should be updated.
|
||||
4. Check that this test now passes.
|
||||
5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.
|
||||
|
||||
Reference in New Issue
Block a user