finish examples 2 and 3

This commit is contained in:
Alexander Soare
2024-03-26 16:13:40 +00:00
parent cb6d1e0871
commit 1ed0110900
10 changed files with 196 additions and 42 deletions

View File

@@ -1,19 +1,56 @@
import pytest
from pathlib import Path
@pytest.mark.parametrize(
"path",
[
"examples/1_visualize_dataset.py",
"examples/2_evaluate_pretrained_policy.py",
"examples/3_train_policy.py",
],
)
def test_example(path):
with open(path, 'r') as file:
def _find_and_replace(text: str, finds: list[str], replaces: list[str]) -> str:
for f, r in zip(finds, replaces):
assert f in text
text = text.replace(f, r)
return text
def test_example_1():
path = "examples/1_visualize_dataset.py"
with open(path, "r") as file:
file_contents = file.read()
exec(file_contents)
if path == "examples/1_visualize_dataset.py":
assert Path("outputs/visualize_dataset/example/episode_0.mp4").exists()
assert Path("outputs/visualize_dataset/example/episode_0.mp4").exists()
def test_examples_3_and_2():
"""
Train a model with example 3, check the outputs.
Evaluate the trained model with example 2, check the outputs.
"""
path = "examples/3_train_policy.py"
with open(path, "r") as file:
file_contents = file.read()
# Do less steps and use CPU.
file_contents = _find_and_replace(
file_contents,
['"offline_steps=5000"', '"device=cuda"'],
['"offline_steps=1"', '"device=cpu"'],
)
exec(file_contents)
for file_name in ["model.pt", "stats.pth", "config.yaml"]:
assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()
path = "examples/2_evaluate_pretrained_policy.py"
with open(path, "r") as file:
file_contents = file.read()
# Do less evals and use CPU.
file_contents = _find_and_replace(
file_contents,
['"eval_episodes=10"', '"rollout_batch_size=10"', '"device=cuda"'],
['"eval_episodes=1"', '"rollout_batch_size=1"','"device=cpu"'],
)
assert Path(f"outputs/train/example_pusht_diffusion").exists()