Validate features during add_frame + Add 2D-to-5D + Add string (#720)
This commit is contained in:
@@ -9,10 +9,11 @@ from PIL import Image
|
||||
|
||||
from lerobot.common.datasets.image_writer import (
|
||||
AsyncImageWriter,
|
||||
image_array_to_image,
|
||||
image_array_to_pil_image,
|
||||
safe_stop_image_writer,
|
||||
write_image,
|
||||
)
|
||||
from tests.fixtures.constants import DUMMY_HWC
|
||||
|
||||
DUMMY_IMAGE = "test_image.png"
|
||||
|
||||
@@ -48,49 +49,62 @@ def test_zero_threads():
|
||||
AsyncImageWriter(num_processes=0, num_threads=0)
|
||||
|
||||
|
||||
def test_image_array_to_image_rgb(img_array_factory):
|
||||
def test_image_array_to_pil_image_float_array_wrong_range_0_255():
|
||||
image = np.random.rand(*DUMMY_HWC) * 255
|
||||
with pytest.raises(ValueError):
|
||||
image_array_to_pil_image(image)
|
||||
|
||||
|
||||
def test_image_array_to_pil_image_float_array_wrong_range_neg_1_1():
|
||||
image = np.random.rand(*DUMMY_HWC) * 2 - 1
|
||||
with pytest.raises(ValueError):
|
||||
image_array_to_pil_image(image)
|
||||
|
||||
|
||||
def test_image_array_to_pil_image_rgb(img_array_factory):
|
||||
img_array = img_array_factory(100, 100)
|
||||
result_image = image_array_to_image(img_array)
|
||||
result_image = image_array_to_pil_image(img_array)
|
||||
assert isinstance(result_image, Image.Image)
|
||||
assert result_image.size == (100, 100)
|
||||
assert result_image.mode == "RGB"
|
||||
|
||||
|
||||
def test_image_array_to_image_pytorch_format(img_array_factory):
|
||||
def test_image_array_to_pil_image_pytorch_format(img_array_factory):
|
||||
img_array = img_array_factory(100, 100).transpose(2, 0, 1)
|
||||
result_image = image_array_to_image(img_array)
|
||||
result_image = image_array_to_pil_image(img_array)
|
||||
assert isinstance(result_image, Image.Image)
|
||||
assert result_image.size == (100, 100)
|
||||
assert result_image.mode == "RGB"
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO: implement")
|
||||
def test_image_array_to_image_single_channel(img_array_factory):
|
||||
def test_image_array_to_pil_image_single_channel(img_array_factory):
|
||||
img_array = img_array_factory(channels=1)
|
||||
result_image = image_array_to_image(img_array)
|
||||
assert isinstance(result_image, Image.Image)
|
||||
assert result_image.size == (100, 100)
|
||||
assert result_image.mode == "L"
|
||||
with pytest.raises(NotImplementedError):
|
||||
image_array_to_pil_image(img_array)
|
||||
|
||||
|
||||
def test_image_array_to_image_float_array(img_array_factory):
|
||||
def test_image_array_to_pil_image_4_channels(img_array_factory):
|
||||
img_array = img_array_factory(channels=4)
|
||||
with pytest.raises(NotImplementedError):
|
||||
image_array_to_pil_image(img_array)
|
||||
|
||||
|
||||
def test_image_array_to_pil_image_float_array(img_array_factory):
|
||||
img_array = img_array_factory(dtype=np.float32)
|
||||
result_image = image_array_to_image(img_array)
|
||||
result_image = image_array_to_pil_image(img_array)
|
||||
assert isinstance(result_image, Image.Image)
|
||||
assert result_image.size == (100, 100)
|
||||
assert result_image.mode == "RGB"
|
||||
assert np.array(result_image).dtype == np.uint8
|
||||
|
||||
|
||||
def test_image_array_to_image_out_of_bounds_float():
|
||||
# Float array with values out of [0, 1]
|
||||
img_array = np.random.uniform(-1, 2, size=(100, 100, 3)).astype(np.float32)
|
||||
result_image = image_array_to_image(img_array)
|
||||
def test_image_array_to_pil_image_uint8_array(img_array_factory):
|
||||
img_array = img_array_factory(dtype=np.float32)
|
||||
result_image = image_array_to_pil_image(img_array)
|
||||
assert isinstance(result_image, Image.Image)
|
||||
assert result_image.size == (100, 100)
|
||||
assert result_image.mode == "RGB"
|
||||
assert np.array(result_image).dtype == np.uint8
|
||||
assert np.array(result_image).min() >= 0 and np.array(result_image).max() <= 255
|
||||
|
||||
|
||||
def test_write_image_numpy(tmp_path, img_array_factory):
|
||||
|
||||
Reference in New Issue
Block a user