Improve dataset examples (#82)

Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
Remi
2024-04-18 11:43:16 +02:00
committed by GitHub
parent d5c4b0c344
commit 0928afd37d
15 changed files with 274 additions and 165 deletions

View File

@@ -1,4 +1,5 @@
from pathlib import Path
import subprocess
def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> str:
@@ -8,23 +9,29 @@ def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> s
return text
def _run_script(path):
subprocess.run(['python', path], check=True)
def test_example_1():
path = "examples/1_visualize_dataset.py"
with open(path, "r") as file:
file_contents = file.read()
exec(file_contents)
assert Path("outputs/visualize_dataset/example/episode_0.mp4").exists()
path = "examples/1_load_hugging_face_dataset.py"
_run_script(path)
assert Path("outputs/examples/1_load_hugging_face_dataset/episode_5.mp4").exists()
def test_examples_3_and_2():
def test_example_2():
path = "examples/2_load_lerobot_dataset.py"
_run_script(path)
assert Path("outputs/examples/2_load_lerobot_dataset/episode_5.mp4").exists()
def test_examples_4_and_3():
"""
Train a model with example 3, check the outputs.
Evaluate the trained model with example 2, check the outputs.
"""
path = "examples/3_train_policy.py"
path = "examples/4_train_policy.py"
with open(path, "r") as file:
file_contents = file.read()
@@ -46,7 +53,7 @@ def test_examples_3_and_2():
for file_name in ["model.pt", "stats.pth", "config.yaml"]:
assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()
path = "examples/2_evaluate_pretrained_policy.py"
path = "examples/3_evaluate_pretrained_policy.py"
with open(path, "r") as file:
file_contents = file.read()