Release cleanup (#132)

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Co-authored-by: Cadene <re.cadene@gmail.com>
This commit is contained in:
Simon Alibert
2024-05-06 03:03:14 +02:00
committed by GitHub
parent 6eaffbef1d
commit f5e76393eb
19 changed files with 312 additions and 237 deletions

View File

@@ -15,7 +15,7 @@ from tests.utils import require_env
def test_available_env_task(env_name: str, task_name: list):
"""
This test verifies that all environments listed in `lerobot/__init__.py` can
be sucessfully imported — if they're installed — and that their
be successfully imported — if they're installed — and that their
`available_tasks_per_env` are valid.
"""
package_name = f"gym_{env_name}"

View File

@@ -41,7 +41,7 @@ def test_factory(env_name, repo_id, policy_name):
)
dataset = make_dataset(cfg)
delta_timestamps = dataset.delta_timestamps
image_keys = dataset.image_keys
camera_keys = dataset.camera_keys
item = dataset[0]
@@ -71,7 +71,7 @@ def test_factory(env_name, repo_id, policy_name):
else:
assert item[key].ndim == ndim, f"{key}"
if key in image_keys:
if key in camera_keys:
assert item[key].dtype == torch.float32, f"{key}"
# TODO(rcadene): we assume for now that image normalization takes place in the model
assert item[key].max() <= 1.0, f"{key}"

View File

@@ -46,7 +46,7 @@ def test_examples_3_and_2():
# Pass empty globals to allow dictionary comprehension https://stackoverflow.com/a/32897127/4391249.
exec(file_contents, {})
for file_name in ["model.safetensors", "config.json", "config.yaml"]:
for file_name in ["model.safetensors", "config.json"]:
assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()
path = "examples/2_evaluate_pretrained_policy.py"
@@ -58,16 +58,16 @@ def test_examples_3_and_2():
file_contents = _find_and_replace(
file_contents,
[
('pretrained_policy_name = "lerobot/diffusion_pusht"', ""),
("pretrained_policy_path = Path(snapshot_download(pretrained_policy_name))", ""),
('pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))', ""),
(
'# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
),
('"eval.n_episodes=10"', '"eval.n_episodes=1"'),
('"eval.batch_size=10"', '"eval.batch_size=1"'),
('"device=cuda"', '"device=cpu"'),
('device = torch.device("cuda")', 'device = torch.device("cpu")'),
("step += 1", "break"),
],
)
assert Path("outputs/train/example_pusht_diffusion").exists()
exec(file_contents, {})
assert Path("outputs/eval/example_pusht_diffusion/rollout.mp4").exists()