forked from tangger/lerobot
Add img and img_tensor factories
This commit is contained in:
19
tests/fixtures/dataset_factories.py
vendored
19
tests/fixtures/dataset_factories.py
vendored
@@ -4,7 +4,9 @@ from unittest.mock import patch
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.common.datasets.utils import (
|
||||
@@ -37,6 +39,14 @@ def get_task_index(tasks_dicts: dict, task: str) -> int:
|
||||
return task_to_task_index[task]
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def img_tensor_factory():
|
||||
def _create_img_tensor(width=100, height=100) -> torch.Tensor:
|
||||
return torch.rand((3, height, width), dtype=torch.float32)
|
||||
|
||||
return _create_img_tensor
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def img_array_factory():
|
||||
def _create_img_array(width=100, height=100) -> np.ndarray:
|
||||
@@ -45,6 +55,15 @@ def img_array_factory():
|
||||
return _create_img_array
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def img_factory(img_array_factory):
|
||||
def _create_img(width=100, height=100) -> PIL.Image.Image:
|
||||
img_array = img_array_factory(width=width, height=height)
|
||||
return PIL.Image.Image.fromarray(img_array)
|
||||
|
||||
return _create_img
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def info_factory():
|
||||
def _create_info(
|
||||
|
||||
Reference in New Issue
Block a user