Dataset v2.0 (#461)

Co-authored-by: Remi <remi.cadene@huggingface.co>
This commit is contained in:
Simon Alibert
2024-11-29 19:04:00 +01:00
committed by GitHub
parent 96c7052777
commit 32eb0cec8f
71 changed files with 6115 additions and 2235 deletions

View File

@@ -50,7 +50,7 @@ def test_get_policy_and_config_classes(policy_name: str):
assert issubclass(config_cls, inspect.signature(policy_cls.__init__).parameters["config"].annotation)
# TODO(aliberts): refactor using lerobot/__init__.py variables
@pytest.mark.skip("TODO after v2 migration / removing hydra")
@pytest.mark.parametrize(
"env_name,policy_name,extra_overrides",
[
@@ -136,7 +136,7 @@ def test_policy(env_name, policy_name, extra_overrides):
# Check that we can make the policy object.
dataset = make_dataset(cfg)
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.meta.stats)
# Check that the policy follows the required protocol.
assert isinstance(
policy, Policy
@@ -195,6 +195,7 @@ def test_policy(env_name, policy_name, extra_overrides):
env.step(action)
@pytest.mark.skip("TODO after v2 migration / removing hydra")
def test_act_backbone_lr():
"""
Test that the ACT policy can be instantiated with a different learning rate for the backbone.
@@ -213,7 +214,7 @@ def test_act_backbone_lr():
assert cfg.training.lr_backbone == 0.001
dataset = make_dataset(cfg)
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.meta.stats)
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
assert len(optimizer.param_groups) == 2
assert optimizer.param_groups[0]["lr"] == cfg.training.lr
@@ -351,6 +352,7 @@ def test_normalize(insert_temporal_dim):
unnormalize(output_batch)
@pytest.mark.skip("TODO after v2 migration / removing hydra")
@pytest.mark.parametrize(
"env_name, policy_name, extra_overrides, file_name_extra",
[
@@ -381,7 +383,7 @@ def test_backward_compatibility(env_name, policy_name, extra_overrides, file_nam
include a report on what changed and how that affected the outputs.
2. Go to the `if __name__ == "__main__"` block of `tests/scripts/save_policy_to_safetensors.py` and
add the policies you want to update the test artifacts for.
3. Run `DATA_DIR=tests/data python tests/scripts/save_policy_to_safetensors.py`. The test artifact
3. Run `python tests/scripts/save_policy_to_safetensors.py`. The test artifact
should be updated.
4. Check that this test now passes.
5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.