Fix tests
This commit is contained in:
@@ -8,7 +8,7 @@ import pytest
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.common.datasets.image_writer import (
|
||||
ImageWriter,
|
||||
AsyncImageWriter,
|
||||
image_array_to_image,
|
||||
safe_stop_image_writer,
|
||||
write_image,
|
||||
@@ -17,8 +17,8 @@ from lerobot.common.datasets.image_writer import (
|
||||
DUMMY_IMAGE = "test_image.png"
|
||||
|
||||
|
||||
def test_init_threading(tmp_path):
|
||||
writer = ImageWriter(write_dir=tmp_path, num_processes=0, num_threads=2)
|
||||
def test_init_threading():
|
||||
writer = AsyncImageWriter(num_processes=0, num_threads=2)
|
||||
try:
|
||||
assert writer.num_processes == 0
|
||||
assert writer.num_threads == 2
|
||||
@@ -30,8 +30,8 @@ def test_init_threading(tmp_path):
|
||||
writer.stop()
|
||||
|
||||
|
||||
def test_init_multiprocessing(tmp_path):
|
||||
writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2)
|
||||
def test_init_multiprocessing():
|
||||
writer = AsyncImageWriter(num_processes=2, num_threads=2)
|
||||
try:
|
||||
assert writer.num_processes == 2
|
||||
assert writer.num_threads == 2
|
||||
@@ -43,35 +43,9 @@ def test_init_multiprocessing(tmp_path):
|
||||
writer.stop()
|
||||
|
||||
|
||||
def test_write_dir_created(tmp_path):
|
||||
write_dir = tmp_path / "non_existent_dir"
|
||||
assert not write_dir.exists()
|
||||
writer = ImageWriter(write_dir=write_dir)
|
||||
try:
|
||||
assert write_dir.exists()
|
||||
finally:
|
||||
writer.stop()
|
||||
|
||||
|
||||
def test_get_image_file_path_and_episode_dir(tmp_path):
|
||||
writer = ImageWriter(write_dir=tmp_path)
|
||||
try:
|
||||
episode_index = 1
|
||||
image_key = "test_key"
|
||||
frame_index = 10
|
||||
expected_episode_dir = tmp_path / f"{image_key}/episode_{episode_index:06d}"
|
||||
expected_path = expected_episode_dir / f"frame_{frame_index:06d}.png"
|
||||
image_file_path = writer.get_image_file_path(episode_index, image_key, frame_index)
|
||||
assert image_file_path == expected_path
|
||||
episode_dir = writer.get_episode_dir(episode_index, image_key)
|
||||
assert episode_dir == expected_episode_dir
|
||||
finally:
|
||||
writer.stop()
|
||||
|
||||
|
||||
def test_zero_threads(tmp_path):
|
||||
def test_zero_threads():
|
||||
with pytest.raises(ValueError):
|
||||
ImageWriter(write_dir=tmp_path, num_processes=0, num_threads=0)
|
||||
AsyncImageWriter(num_processes=0, num_threads=0)
|
||||
|
||||
|
||||
def test_image_array_to_image_rgb(img_array_factory):
|
||||
@@ -148,7 +122,7 @@ def test_write_image_exception(tmp_path):
|
||||
|
||||
|
||||
def test_save_image_numpy(tmp_path, img_array_factory):
|
||||
writer = ImageWriter(write_dir=tmp_path)
|
||||
writer = AsyncImageWriter()
|
||||
try:
|
||||
image_array = img_array_factory()
|
||||
fpath = tmp_path / DUMMY_IMAGE
|
||||
@@ -163,7 +137,7 @@ def test_save_image_numpy(tmp_path, img_array_factory):
|
||||
|
||||
|
||||
def test_save_image_numpy_multiprocessing(tmp_path, img_array_factory):
|
||||
writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2)
|
||||
writer = AsyncImageWriter(num_processes=2, num_threads=2)
|
||||
try:
|
||||
image_array = img_array_factory()
|
||||
fpath = tmp_path / DUMMY_IMAGE
|
||||
@@ -177,7 +151,7 @@ def test_save_image_numpy_multiprocessing(tmp_path, img_array_factory):
|
||||
|
||||
|
||||
def test_save_image_torch(tmp_path, img_tensor_factory):
|
||||
writer = ImageWriter(write_dir=tmp_path)
|
||||
writer = AsyncImageWriter()
|
||||
try:
|
||||
image_tensor = img_tensor_factory()
|
||||
fpath = tmp_path / DUMMY_IMAGE
|
||||
@@ -193,7 +167,7 @@ def test_save_image_torch(tmp_path, img_tensor_factory):
|
||||
|
||||
|
||||
def test_save_image_torch_multiprocessing(tmp_path, img_tensor_factory):
|
||||
writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2)
|
||||
writer = AsyncImageWriter(num_processes=2, num_threads=2)
|
||||
try:
|
||||
image_tensor = img_tensor_factory()
|
||||
fpath = tmp_path / DUMMY_IMAGE
|
||||
@@ -208,7 +182,7 @@ def test_save_image_torch_multiprocessing(tmp_path, img_tensor_factory):
|
||||
|
||||
|
||||
def test_save_image_pil(tmp_path, img_factory):
|
||||
writer = ImageWriter(write_dir=tmp_path)
|
||||
writer = AsyncImageWriter()
|
||||
try:
|
||||
image_pil = img_factory()
|
||||
fpath = tmp_path / DUMMY_IMAGE
|
||||
@@ -223,7 +197,7 @@ def test_save_image_pil(tmp_path, img_factory):
|
||||
|
||||
|
||||
def test_save_image_pil_multiprocessing(tmp_path, img_factory):
|
||||
writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2)
|
||||
writer = AsyncImageWriter(num_processes=2, num_threads=2)
|
||||
try:
|
||||
image_pil = img_factory()
|
||||
fpath = tmp_path / DUMMY_IMAGE
|
||||
@@ -237,10 +211,10 @@ def test_save_image_pil_multiprocessing(tmp_path, img_factory):
|
||||
|
||||
|
||||
def test_save_image_invalid_data(tmp_path):
|
||||
writer = ImageWriter(write_dir=tmp_path)
|
||||
writer = AsyncImageWriter()
|
||||
try:
|
||||
image_array = "invalid data"
|
||||
fpath = writer.get_image_file_path(0, "test_key", 0)
|
||||
fpath = tmp_path / DUMMY_IMAGE
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
with patch("builtins.print") as mock_print:
|
||||
writer.save_image(image_array, fpath)
|
||||
@@ -252,47 +226,47 @@ def test_save_image_invalid_data(tmp_path):
|
||||
|
||||
|
||||
def test_save_image_after_stop(tmp_path, img_array_factory):
|
||||
writer = ImageWriter(write_dir=tmp_path)
|
||||
writer = AsyncImageWriter()
|
||||
writer.stop()
|
||||
image_array = img_array_factory()
|
||||
fpath = writer.get_image_file_path(0, "test_key", 0)
|
||||
fpath = tmp_path / DUMMY_IMAGE
|
||||
writer.save_image(image_array, fpath)
|
||||
time.sleep(1)
|
||||
assert not fpath.exists()
|
||||
|
||||
|
||||
def test_stop(tmp_path):
|
||||
writer = ImageWriter(write_dir=tmp_path, num_processes=0, num_threads=2)
|
||||
def test_stop():
|
||||
writer = AsyncImageWriter(num_processes=0, num_threads=2)
|
||||
writer.stop()
|
||||
assert not any(t.is_alive() for t in writer.threads)
|
||||
|
||||
|
||||
def test_stop_multiprocessing(tmp_path):
|
||||
writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2)
|
||||
def test_stop_multiprocessing():
|
||||
writer = AsyncImageWriter(num_processes=2, num_threads=2)
|
||||
writer.stop()
|
||||
assert not any(p.is_alive() for p in writer.processes)
|
||||
|
||||
|
||||
def test_multiple_stops(tmp_path):
|
||||
writer = ImageWriter(write_dir=tmp_path)
|
||||
def test_multiple_stops():
|
||||
writer = AsyncImageWriter()
|
||||
writer.stop()
|
||||
writer.stop() # Should not raise an exception
|
||||
assert not any(t.is_alive() for t in writer.threads)
|
||||
|
||||
|
||||
def test_multiple_stops_multiprocessing(tmp_path):
|
||||
writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2)
|
||||
def test_multiple_stops_multiprocessing():
|
||||
writer = AsyncImageWriter(num_processes=2, num_threads=2)
|
||||
writer.stop()
|
||||
writer.stop() # Should not raise an exception
|
||||
assert not any(t.is_alive() for t in writer.threads)
|
||||
|
||||
|
||||
def test_wait_until_done(tmp_path, img_array_factory):
|
||||
writer = ImageWriter(write_dir=tmp_path, num_processes=0, num_threads=4)
|
||||
writer = AsyncImageWriter(num_processes=0, num_threads=4)
|
||||
try:
|
||||
num_images = 100
|
||||
image_arrays = [img_array_factory(width=500, height=500) for _ in range(num_images)]
|
||||
fpaths = [writer.get_image_file_path(0, "test_key", i) for i in range(num_images)]
|
||||
fpaths = [tmp_path / f"frame_{i:06d}.png" for i in range(num_images)]
|
||||
for image_array, fpath in zip(image_arrays, fpaths, strict=True):
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
writer.save_image(image_array, fpath)
|
||||
@@ -306,11 +280,11 @@ def test_wait_until_done(tmp_path, img_array_factory):
|
||||
|
||||
|
||||
def test_wait_until_done_multiprocessing(tmp_path, img_array_factory):
|
||||
writer = ImageWriter(write_dir=tmp_path, num_processes=2, num_threads=2)
|
||||
writer = AsyncImageWriter(num_processes=2, num_threads=2)
|
||||
try:
|
||||
num_images = 100
|
||||
image_arrays = [img_array_factory() for _ in range(num_images)]
|
||||
fpaths = [writer.get_image_file_path(0, "test_key", i) for i in range(num_images)]
|
||||
fpaths = [tmp_path / f"frame_{i:06d}.png" for i in range(num_images)]
|
||||
for image_array, fpath in zip(image_arrays, fpaths, strict=True):
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
writer.save_image(image_array, fpath)
|
||||
@@ -324,7 +298,7 @@ def test_wait_until_done_multiprocessing(tmp_path, img_array_factory):
|
||||
|
||||
|
||||
def test_exception_handling(tmp_path, img_array_factory):
|
||||
writer = ImageWriter(write_dir=tmp_path)
|
||||
writer = AsyncImageWriter()
|
||||
try:
|
||||
image_array = img_array_factory()
|
||||
with (
|
||||
@@ -338,7 +312,7 @@ def test_exception_handling(tmp_path, img_array_factory):
|
||||
|
||||
|
||||
def test_with_different_image_formats(tmp_path, img_array_factory):
|
||||
writer = ImageWriter(write_dir=tmp_path)
|
||||
writer = AsyncImageWriter()
|
||||
try:
|
||||
image_array = img_array_factory()
|
||||
formats = ["png", "jpeg", "bmp"]
|
||||
@@ -353,7 +327,7 @@ def test_with_different_image_formats(tmp_path, img_array_factory):
|
||||
def test_safe_stop_image_writer_decorator():
|
||||
class MockDataset:
|
||||
def __init__(self):
|
||||
self.image_writer = MagicMock(spec=ImageWriter)
|
||||
self.image_writer = MagicMock(spec=AsyncImageWriter)
|
||||
|
||||
@safe_stop_image_writer
|
||||
def function_that_raises_exception(dataset=None):
|
||||
@@ -369,10 +343,10 @@ def test_safe_stop_image_writer_decorator():
|
||||
|
||||
|
||||
def test_main_process_time(tmp_path, img_tensor_factory):
|
||||
writer = ImageWriter(write_dir=tmp_path)
|
||||
writer = AsyncImageWriter()
|
||||
try:
|
||||
image_tensor = img_tensor_factory()
|
||||
fpath = tmp_path / "test_main_process_time.png"
|
||||
fpath = tmp_path / DUMMY_IMAGE
|
||||
start_time = time.perf_counter()
|
||||
writer.save_image(image_tensor, fpath)
|
||||
end_time = time.perf_counter()
|
||||
|
||||
Reference in New Issue
Block a user