Add img and img_tensor factories

This commit is contained in:
Simon Alibert
2024-11-02 13:06:38 +01:00
parent 293bdc7f67
commit 375abd3020
2 changed files with 52 additions and 32 deletions

View File

@@ -4,7 +4,9 @@ from unittest.mock import patch
import datasets
import numpy as np
import PIL.Image
import pytest
import torch
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.utils import (
@@ -37,6 +39,14 @@ def get_task_index(tasks_dicts: dict, task: str) -> int:
return task_to_task_index[task]
@pytest.fixture(scope="session")
def img_tensor_factory():
def _create_img_tensor(width=100, height=100) -> torch.Tensor:
return torch.rand((3, height, width), dtype=torch.float32)
return _create_img_tensor
@pytest.fixture(scope="session")
def img_array_factory():
def _create_img_array(width=100, height=100) -> np.ndarray:
@@ -45,6 +55,15 @@ def img_array_factory():
return _create_img_array
@pytest.fixture(scope="session")
def img_factory(img_array_factory):
def _create_img(width=100, height=100) -> PIL.Image.Image:
img_array = img_array_factory(width=width, height=height)
return PIL.Image.Image.fromarray(img_array)
return _create_img
@pytest.fixture(scope="session")
def info_factory():
def _create_info(