Fix tests

This commit is contained in:
Simon Alibert
2024-11-05 19:09:12 +01:00
parent aed9f4036a
commit f3630ad910
13 changed files with 437 additions and 496 deletions

View File

@@ -8,7 +8,7 @@ import pytest
from PIL import Image
from lerobot.common.datasets.image_writer import (
ImageWriter,
AsyncImageWriter,
image_array_to_image,
safe_stop_image_writer,
write_image,
@@ -17,8 +17,8 @@ from lerobot.common.datasets.image_writer import (
DUMMY_IMAGE = "test_image.png"
def test_init_threading(tmp_path):
writer = ImageWriter(write_dir=tmp_path, num_processes=0, num_threads=2)
def test_init_threading():
writer = AsyncImageWriter(num_processes=0, num_threads=2)
try:
assert writer.num_processes == 0
assert writer.num_threads == 2
@@ -30,8 +30,8 @@ def test_init_threading(tmp_path):
writer.stop()
def test_init_multiprocessing(tmp_path):
writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2)
def test_init_multiprocessing():
writer = AsyncImageWriter(num_processes=2, num_threads=2)
try:
assert writer.num_processes == 2
assert writer.num_threads == 2
@@ -43,35 +43,9 @@ def test_init_multiprocessing(tmp_path):
writer.stop()
def test_write_dir_created(tmp_path):
write_dir = tmp_path / "non_existent_dir"
assert not write_dir.exists()
writer = ImageWriter(write_dir=write_dir)
try:
assert write_dir.exists()
finally:
writer.stop()
def test_get_image_file_path_and_episode_dir(tmp_path):
writer = ImageWriter(write_dir=tmp_path)
try:
episode_index = 1
image_key = "test_key"
frame_index = 10
expected_episode_dir = tmp_path / f"{image_key}/episode_{episode_index:06d}"
expected_path = expected_episode_dir / f"frame_{frame_index:06d}.png"
image_file_path = writer.get_image_file_path(episode_index, image_key, frame_index)
assert image_file_path == expected_path
episode_dir = writer.get_episode_dir(episode_index, image_key)
assert episode_dir == expected_episode_dir
finally:
writer.stop()
def test_zero_threads(tmp_path):
def test_zero_threads():
with pytest.raises(ValueError):
ImageWriter(write_dir=tmp_path, num_processes=0, num_threads=0)
AsyncImageWriter(num_processes=0, num_threads=0)
def test_image_array_to_image_rgb(img_array_factory):
@@ -148,7 +122,7 @@ def test_write_image_exception(tmp_path):
def test_save_image_numpy(tmp_path, img_array_factory):
writer = ImageWriter(write_dir=tmp_path)
writer = AsyncImageWriter()
try:
image_array = img_array_factory()
fpath = tmp_path / DUMMY_IMAGE
@@ -163,7 +137,7 @@ def test_save_image_numpy(tmp_path, img_array_factory):
def test_save_image_numpy_multiprocessing(tmp_path, img_array_factory):
writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2)
writer = AsyncImageWriter(num_processes=2, num_threads=2)
try:
image_array = img_array_factory()
fpath = tmp_path / DUMMY_IMAGE
@@ -177,7 +151,7 @@ def test_save_image_numpy_multiprocessing(tmp_path, img_array_factory):
def test_save_image_torch(tmp_path, img_tensor_factory):
writer = ImageWriter(write_dir=tmp_path)
writer = AsyncImageWriter()
try:
image_tensor = img_tensor_factory()
fpath = tmp_path / DUMMY_IMAGE
@@ -193,7 +167,7 @@ def test_save_image_torch(tmp_path, img_tensor_factory):
def test_save_image_torch_multiprocessing(tmp_path, img_tensor_factory):
writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2)
writer = AsyncImageWriter(num_processes=2, num_threads=2)
try:
image_tensor = img_tensor_factory()
fpath = tmp_path / DUMMY_IMAGE
@@ -208,7 +182,7 @@ def test_save_image_torch_multiprocessing(tmp_path, img_tensor_factory):
def test_save_image_pil(tmp_path, img_factory):
writer = ImageWriter(write_dir=tmp_path)
writer = AsyncImageWriter()
try:
image_pil = img_factory()
fpath = tmp_path / DUMMY_IMAGE
@@ -223,7 +197,7 @@ def test_save_image_pil(tmp_path, img_factory):
def test_save_image_pil_multiprocessing(tmp_path, img_factory):
writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2)
writer = AsyncImageWriter(num_processes=2, num_threads=2)
try:
image_pil = img_factory()
fpath = tmp_path / DUMMY_IMAGE
@@ -237,10 +211,10 @@ def test_save_image_pil_multiprocessing(tmp_path, img_factory):
def test_save_image_invalid_data(tmp_path):
writer = ImageWriter(write_dir=tmp_path)
writer = AsyncImageWriter()
try:
image_array = "invalid data"
fpath = writer.get_image_file_path(0, "test_key", 0)
fpath = tmp_path / DUMMY_IMAGE
fpath.parent.mkdir(parents=True, exist_ok=True)
with patch("builtins.print") as mock_print:
writer.save_image(image_array, fpath)
@@ -252,47 +226,47 @@ def test_save_image_invalid_data(tmp_path):
def test_save_image_after_stop(tmp_path, img_array_factory):
writer = ImageWriter(write_dir=tmp_path)
writer = AsyncImageWriter()
writer.stop()
image_array = img_array_factory()
fpath = writer.get_image_file_path(0, "test_key", 0)
fpath = tmp_path / DUMMY_IMAGE
writer.save_image(image_array, fpath)
time.sleep(1)
assert not fpath.exists()
def test_stop(tmp_path):
writer = ImageWriter(write_dir=tmp_path, num_processes=0, num_threads=2)
def test_stop():
writer = AsyncImageWriter(num_processes=0, num_threads=2)
writer.stop()
assert not any(t.is_alive() for t in writer.threads)
def test_stop_multiprocessing(tmp_path):
writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2)
def test_stop_multiprocessing():
writer = AsyncImageWriter(num_processes=2, num_threads=2)
writer.stop()
assert not any(p.is_alive() for p in writer.processes)
def test_multiple_stops(tmp_path):
writer = ImageWriter(write_dir=tmp_path)
def test_multiple_stops():
writer = AsyncImageWriter()
writer.stop()
writer.stop() # Should not raise an exception
assert not any(t.is_alive() for t in writer.threads)
def test_multiple_stops_multiprocessing(tmp_path):
writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2)
def test_multiple_stops_multiprocessing():
writer = AsyncImageWriter(num_processes=2, num_threads=2)
writer.stop()
writer.stop() # Should not raise an exception
assert not any(t.is_alive() for t in writer.threads)
def test_wait_until_done(tmp_path, img_array_factory):
writer = ImageWriter(write_dir=tmp_path, num_processes=0, num_threads=4)
writer = AsyncImageWriter(num_processes=0, num_threads=4)
try:
num_images = 100
image_arrays = [img_array_factory(width=500, height=500) for _ in range(num_images)]
fpaths = [writer.get_image_file_path(0, "test_key", i) for i in range(num_images)]
fpaths = [tmp_path / f"frame_{i:06d}.png" for i in range(num_images)]
for image_array, fpath in zip(image_arrays, fpaths, strict=True):
fpath.parent.mkdir(parents=True, exist_ok=True)
writer.save_image(image_array, fpath)
@@ -306,11 +280,11 @@ def test_wait_until_done(tmp_path, img_array_factory):
def test_wait_until_done_multiprocessing(tmp_path, img_array_factory):
writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2)
writer = AsyncImageWriter(num_processes=2, num_threads=2)
try:
num_images = 100
image_arrays = [img_array_factory() for _ in range(num_images)]
fpaths = [writer.get_image_file_path(0, "test_key", i) for i in range(num_images)]
fpaths = [tmp_path / f"frame_{i:06d}.png" for i in range(num_images)]
for image_array, fpath in zip(image_arrays, fpaths, strict=True):
fpath.parent.mkdir(parents=True, exist_ok=True)
writer.save_image(image_array, fpath)
@@ -324,7 +298,7 @@ def test_wait_until_done_multiprocessing(tmp_path, img_array_factory):
def test_exception_handling(tmp_path, img_array_factory):
writer = ImageWriter(write_dir=tmp_path)
writer = AsyncImageWriter()
try:
image_array = img_array_factory()
with (
@@ -338,7 +312,7 @@ def test_exception_handling(tmp_path, img_array_factory):
def test_with_different_image_formats(tmp_path, img_array_factory):
writer = ImageWriter(write_dir=tmp_path)
writer = AsyncImageWriter()
try:
image_array = img_array_factory()
formats = ["png", "jpeg", "bmp"]
@@ -353,7 +327,7 @@ def test_with_different_image_formats(tmp_path, img_array_factory):
def test_safe_stop_image_writer_decorator():
class MockDataset:
def __init__(self):
self.image_writer = MagicMock(spec=ImageWriter)
self.image_writer = MagicMock(spec=AsyncImageWriter)
@safe_stop_image_writer
def function_that_raises_exception(dataset=None):
@@ -369,10 +343,10 @@ def test_safe_stop_image_writer_decorator():
def test_main_process_time(tmp_path, img_tensor_factory):
writer = ImageWriter(write_dir=tmp_path)
writer = AsyncImageWriter()
try:
image_tensor = img_tensor_factory()
fpath = tmp_path / "test_main_process_time.png"
fpath = tmp_path / DUMMY_IMAGE
start_time = time.perf_counter()
writer.save_image(image_tensor, fpath)
end_time = time.perf_counter()