Merge branch 'main' into user/michel-aractingi/2024-11-27-port-hil-serl

This commit is contained in:
Michel Aractingi
2024-12-10 16:02:49 +01:00
27 changed files with 752 additions and 2031 deletions

View File

@@ -158,7 +158,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
assert dataset.meta.total_episodes == 2
assert len(dataset) == 2
replay(robot, episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False)
replay(robot, episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False, local_files_only=True)
# TODO(rcadene, aliberts): rethink this design
if robot_type == "aloha":
@@ -295,24 +295,12 @@ def test_resume_record(tmpdir, request, robot_type, mock):
dataset = record(**record_kwargs)
assert len(dataset) == 1, f"`dataset` should contain 1 frame, not {len(dataset)}"
# init_dataset_return_value = {}
# def wrapped_init_dataset(*args, **kwargs):
# nonlocal init_dataset_return_value
# init_dataset_return_value = init_dataset(*args, **kwargs)
# return init_dataset_return_value
# with patch("lerobot.scripts.control_robot.init_dataset", wraps=wrapped_init_dataset):
with pytest.raises(FileExistsError):
# Dataset already exists, but resume=False by default
record(**record_kwargs)
dataset = record(**record_kwargs, resume=True)
assert len(dataset) == 2, f"`dataset` should contain 2 frames, not {len(dataset)}"
# assert (
# init_dataset_return_value["num_episodes"] == 2
# ), "`init_dataset` should load the previous episode"
@pytest.mark.parametrize("robot_type, mock", [("koch", True)])

View File

@@ -383,7 +383,7 @@ def test_backward_compatibility(env_name, policy_name, extra_overrides, file_nam
include a report on what changed and how that affected the outputs.
2. Go to the `if __name__ == "__main__"` block of `tests/scripts/save_policy_to_safetensors.py` and
add the policies you want to update the test artifacts for.
3. Run `DATA_DIR=tests/data python tests/scripts/save_policy_to_safetensors.py`. The test artifact
3. Run `python tests/scripts/save_policy_to_safetensors.py`. The test artifact
should be updated.
4. Check that this test now passes.
5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.

View File

@@ -5,7 +5,7 @@ we skip them for now in our CI.
Example to run backward compatiblity tests locally:
```
DATA_DIR=tests/data python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility
python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility
```
"""
@@ -330,7 +330,7 @@ def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_
],
)
@pytest.mark.skip(
"Not compatible with our CI since it downloads raw datasets. Run with `DATA_DIR=tests/data python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility`"
"Not compatible with our CI since it downloads raw datasets. Run with `python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility`"
)
def test_push_dataset_to_hub_pusht_backward_compatibility(tmpdir, raw_format, repo_id):
_, dataset_id = repo_id.split("/")