This commit is contained in:
Simon Alibert
2024-06-06 15:23:49 +00:00
parent bdc0ebd36a
commit a86f387554
3 changed files with 102 additions and 8 deletions

View File

@@ -1,9 +1,17 @@
from pathlib import Path
import numpy as np
from omegaconf import OmegaConf
import pytest
import torch
from torchvision.transforms import v2
from torchvision.transforms.v2 import functional as F # noqa: N812
from PIL import Image
from safetensors.torch import load_file
from lerobot.common.datasets.transforms import RandomSubsetApply, RangeRandomSharpness
from lerobot.common.datasets.transforms import RandomSubsetApply, RangeRandomSharpness, make_transforms
from lerobot.common.datasets.utils import flatten_dict
from lerobot.common.utils.utils import init_hydra_config, seeded_context
from tests.utils import DEFAULT_CONFIG_PATH
class TestRandomSubsetApply:
@@ -76,5 +84,84 @@ class TestRangeRandomSharpness:
class TestMakeTransforms:
...
# TODO
@pytest.fixture(autouse=True)
def setup(self):
"""Seed should be the same as the one that was used to generate artifacts"""
self.config = {
"enable": True,
"max_num_transforms": 1,
"random_order": False,
"brightness": {
"weight": 0,
"min": 0.0,
"max": 2.0
},
"contrast": {
"weight": 0,
"min": 0.0,
"max": 2.0,
},
"saturation": {
"weight": 0,
"min": 0.0,
"max": 2.0,
},
"hue": {
"weight": 0,
"min": -0.5,
"max": 0.5,
},
"sharpness": {
"weight": 0,
"min": 0.0,
"max": 2.0,
},
}
self.path = Path("tests/data/save_image_transforms")
# self.expected_frames = load_file(self.path / f"transformed_frames_1336.safetensors")
self.original_frame = self.load_png_to_tensor(self.path / "original_frame.png")
# self.original_frame = self.expected_frames["original_frame"]
self.transforms = {
"brightness": v2.ColorJitter(brightness=(0.0, 2.0)),
"contrast": v2.ColorJitter(contrast=(0.0, 2.0)),
"saturation": v2.ColorJitter(saturation=(0.0, 2.0)),
"hue": v2.ColorJitter(hue=(-0.5, 0.5)),
"sharpness": RangeRandomSharpness(0.0, 2.0),
}
@staticmethod
def load_png_to_tensor(path: Path):
return torch.from_numpy(np.array(Image.open(path).convert('RGB'))).permute(2, 0, 1)
@pytest.mark.parametrize(
"transform_key, seed",
[
("brightness", 1336),
("contrast", 1336),
("saturation", 1336),
("hue", 1336),
("sharpness", 1336),
]
)
def test_single_transform(self, transform_key, seed):
config = self.config
config[transform_key]["weight"] = 1
cfg = OmegaConf.create(config)
transform = make_transforms(cfg, to_dtype=torch.uint8)
# expected_t = self.transforms[transform_key]
with seeded_context(seed):
actual = transform(self.original_frame)
# torch.manual_seed(42)
# actual = actual_t(self.original_frame)
# torch.manual_seed(42)
# expected = expected_t(self.original_frame)
# with seeded_context(1336):
# expected = expected_t(self.original_frame)
expected = self.load_png_to_tensor(self.path / f"{seed}_{transform_key}.png")
# # expected = self.expected_frames[transform_key]
to_pil = v2.ToPILImage()
to_pil(actual).save(self.path / f"{seed}_{transform_key}_test.png", quality=100)
torch.testing.assert_close(actual, expected)