Use HWC for images
This commit is contained in:
19
tests/fixtures/dataset_factories.py
vendored
19
tests/fixtures/dataset_factories.py
vendored
@@ -27,15 +27,6 @@ from tests.fixtures.defaults import (
|
||||
)
|
||||
|
||||
|
||||
def make_dummy_shapes(keys: list[str] | None = None, camera_keys: list[str] | None = None) -> dict:
|
||||
shapes = {}
|
||||
if keys:
|
||||
shapes.update({key: 10 for key in keys})
|
||||
if camera_keys:
|
||||
shapes.update({key: {"width": 100, "height": 70, "channels": 3} for key in camera_keys})
|
||||
return shapes
|
||||
|
||||
|
||||
def get_task_index(task_dicts: dict, task: str) -> int:
|
||||
tasks = {d["task_index"]: d["task"] for d in task_dicts}
|
||||
task_to_task_index = {task: task_idx for task_idx, task in tasks.items()}
|
||||
@@ -44,7 +35,7 @@ def get_task_index(task_dicts: dict, task: str) -> int:
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def img_tensor_factory():
|
||||
def _create_img_tensor(width=100, height=100, channels=3, dtype=torch.float32) -> torch.Tensor:
|
||||
def _create_img_tensor(height=100, width=100, channels=3, dtype=torch.float32) -> torch.Tensor:
|
||||
return torch.rand((channels, height, width), dtype=dtype)
|
||||
|
||||
return _create_img_tensor
|
||||
@@ -52,7 +43,7 @@ def img_tensor_factory():
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def img_array_factory():
|
||||
def _create_img_array(width=100, height=100, channels=3, dtype=np.uint8) -> np.ndarray:
|
||||
def _create_img_array(height=100, width=100, channels=3, dtype=np.uint8) -> np.ndarray:
|
||||
if np.issubdtype(dtype, np.unsignedinteger):
|
||||
# Int array in [0, 255] range
|
||||
img_array = np.random.randint(0, 256, size=(height, width, channels), dtype=dtype)
|
||||
@@ -68,8 +59,8 @@ def img_array_factory():
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def img_factory(img_array_factory):
|
||||
def _create_img(width=100, height=100) -> PIL.Image.Image:
|
||||
img_array = img_array_factory(width=width, height=height)
|
||||
def _create_img(height=100, width=100) -> PIL.Image.Image:
|
||||
img_array = img_array_factory(height=height, width=width)
|
||||
return PIL.Image.fromarray(img_array)
|
||||
|
||||
return _create_img
|
||||
@@ -259,7 +250,7 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
|
||||
for key, ft in features.items():
|
||||
if ft["dtype"] == "image":
|
||||
robot_cols[key] = [
|
||||
img_array_factory(width=ft["shapes"][0], height=ft["shapes"][1])
|
||||
img_array_factory(height=ft["shapes"][1], width=ft["shapes"][0])
|
||||
for _ in range(len(index_col))
|
||||
]
|
||||
elif ft["shape"][0] > 1 and ft["dtype"] != "video":
|
||||
|
||||
4
tests/fixtures/defaults.py
vendored
4
tests/fixtures/defaults.py
vendored
@@ -16,8 +16,8 @@ DUMMY_MOTOR_FEATURES = {
|
||||
},
|
||||
}
|
||||
DUMMY_CAMERA_FEATURES = {
|
||||
"laptop": {"shape": (640, 480, 3), "names": ["width", "height", "channels"], "info": None},
|
||||
"phone": {"shape": (640, 480, 3), "names": ["width", "height", "channels"], "info": None},
|
||||
"laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
|
||||
"phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
|
||||
}
|
||||
DEFAULT_FPS = 30
|
||||
DUMMY_VIDEO_INFO = {
|
||||
|
||||
@@ -265,7 +265,7 @@ def test_wait_until_done(tmp_path, img_array_factory):
|
||||
writer = AsyncImageWriter(num_processes=0, num_threads=4)
|
||||
try:
|
||||
num_images = 100
|
||||
image_arrays = [img_array_factory(width=500, height=500) for _ in range(num_images)]
|
||||
image_arrays = [img_array_factory(height=500, width=500) for _ in range(num_images)]
|
||||
fpaths = [tmp_path / f"frame_{i:06d}.png" for i in range(num_images)]
|
||||
for image_array, fpath in zip(image_arrays, fpaths, strict=True):
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
Reference in New Issue
Block a user