forked from tangger/lerobot
WIP
This commit is contained in:
@@ -19,7 +19,7 @@ import torch
|
|||||||
from omegaconf import ListConfig, OmegaConf
|
from omegaconf import ListConfig, OmegaConf
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
|
||||||
from lerobot.common.datasets.transforms import make_image_transforms
|
from lerobot.common.datasets.transforms import get_image_transforms
|
||||||
|
|
||||||
|
|
||||||
def resolve_delta_timestamps(cfg):
|
def resolve_delta_timestamps(cfg):
|
||||||
@@ -72,7 +72,22 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
|
|||||||
|
|
||||||
resolve_delta_timestamps(cfg)
|
resolve_delta_timestamps(cfg)
|
||||||
|
|
||||||
image_transforms = make_image_transforms(cfg.image_transforms) if cfg.image_transforms.enable else None
|
image_transforms = None
|
||||||
|
if cfg.image_transforms.enable:
|
||||||
|
image_transforms = get_image_transforms(
|
||||||
|
brightness_weight=cfg.brightness.weight,
|
||||||
|
brightness_min_max=cfg.brightness.min_max,
|
||||||
|
contrast_weight=cfg.contrast.weight,
|
||||||
|
contrast_min_max=cfg.contrast.min_max,
|
||||||
|
saturation_weight=cfg.saturation.weight,
|
||||||
|
saturation_min_max=cfg.saturation.min_max,
|
||||||
|
hue_weight=cfg.hue.weight,
|
||||||
|
hue_min_max=cfg.hue.min_max,
|
||||||
|
sharpness_weight=cfg.sharpness.weight,
|
||||||
|
sharpness_min_max=cfg.sharpness.min_max,
|
||||||
|
max_num_transforms=cfg.max_num_transforms,
|
||||||
|
random_order=cfg.random_order,
|
||||||
|
)
|
||||||
|
|
||||||
if isinstance(cfg.dataset_repo_id, str):
|
if isinstance(cfg.dataset_repo_id, str):
|
||||||
dataset = LeRobotDataset(
|
dataset = LeRobotDataset(
|
||||||
|
|||||||
@@ -98,26 +98,60 @@ class RangeRandomSharpness(Transform):
|
|||||||
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
|
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
|
||||||
|
|
||||||
|
|
||||||
def make_image_transforms(cfg, to_dtype: torch.dtype = torch.float32):
|
def get_image_transforms(
|
||||||
transforms_list = [
|
brightness_weight: float = 1.0,
|
||||||
v2.ColorJitter(brightness=(cfg.brightness.min, cfg.brightness.max)),
|
brightness_min_max: tuple[float, float] | None = None,
|
||||||
v2.ColorJitter(contrast=(cfg.contrast.min, cfg.contrast.max)),
|
contrast_weight: float = 1.0,
|
||||||
v2.ColorJitter(saturation=(cfg.saturation.min, cfg.saturation.max)),
|
contrast_min_max: tuple[float, float] | None = None,
|
||||||
v2.ColorJitter(hue=(cfg.hue.min, cfg.hue.max)),
|
saturation_weight: float = 1.0,
|
||||||
RangeRandomSharpness(cfg.sharpness.min, cfg.sharpness.max),
|
saturation_min_max: tuple[float, float] | None = None,
|
||||||
]
|
hue_weight: float = 1.0,
|
||||||
transforms_weights = [
|
hue_min_max: tuple[float, float] | None = None,
|
||||||
cfg.brightness.weight,
|
sharpness_weight: float = 1.0,
|
||||||
cfg.contrast.weight,
|
sharpness_min_max: tuple[float, float] | None = None,
|
||||||
cfg.saturation.weight,
|
max_num_transforms: int | None = None,
|
||||||
cfg.hue.weight,
|
random_order: bool = False,
|
||||||
cfg.sharpness.weight,
|
):
|
||||||
]
|
|
||||||
|
def check_value_error(name, weight, min_max):
|
||||||
|
if min_max is not None:
|
||||||
|
if len(min_max) != 2:
|
||||||
|
raise ValueError(f"`{name}_min_max` is expected to be a tuple of 2 dimensions, but {min_max} provided.")
|
||||||
|
if weight < 0.:
|
||||||
|
raise ValueError(f"`{name}_weight` is expected to be 0 or positive, but is negative ({weight}).")
|
||||||
|
|
||||||
transforms = RandomSubsetApply(
|
check_value_error("brightness", brightness_weight, brightness_min_max)
|
||||||
transforms_list, p=transforms_weights, n_subset=cfg.max_num_transforms, random_order=cfg.random_order
|
check_value_error("contrast", contrast_weight, contrast_min_max)
|
||||||
|
check_value_error("saturation", saturation_weight, saturation_min_max)
|
||||||
|
check_value_error("hue", hue_weight, hue_min_max)
|
||||||
|
check_value_error("sharpness", sharpness_weight, sharpness_min_max)
|
||||||
|
|
||||||
|
weights = []
|
||||||
|
transforms = []
|
||||||
|
if brightness_min_max is not None:
|
||||||
|
weights.append(brightness_weight)
|
||||||
|
transforms.append(v2.ColorJitter(brightness=brightness_min_max))
|
||||||
|
if contrast_min_max is not None:
|
||||||
|
weights.append(contrast_weight)
|
||||||
|
transforms.append(v2.ColorJitter(contrast=contrast_min_max))
|
||||||
|
if saturation_min_max is not None:
|
||||||
|
weights.append(saturation_weight)
|
||||||
|
transforms.append(v2.ColorJitter(saturation=saturation_min_max))
|
||||||
|
if hue_min_max is not None:
|
||||||
|
weights.append(hue_weight)
|
||||||
|
transforms.append(v2.ColorJitter(hue=hue_min_max))
|
||||||
|
if sharpness_min_max is not None:
|
||||||
|
weights.append(sharpness_weight)
|
||||||
|
transforms.append(RangeRandomSharpness(**sharpness_min_max))
|
||||||
|
|
||||||
|
if max_num_transforms is None:
|
||||||
|
n_subset = len(transforms)
|
||||||
|
else:
|
||||||
|
n_subset = min(len(transforms), max_num_transforms)
|
||||||
|
|
||||||
|
final_transforms = RandomSubsetApply(
|
||||||
|
transforms, p=weights, n_subset=n_subset, random_order=random_order
|
||||||
)
|
)
|
||||||
|
|
||||||
# return transforms
|
# TODO(rcadene, aliberts): add v2.ToDtype float16?
|
||||||
# return v2.Compose([transforms, v2.ToDtype(to_dtype, scale=True)])
|
return final_transforms
|
||||||
return v2.Compose([transforms, v2.ToDtype(to_dtype, scale=False)])
|
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ image_transforms:
|
|||||||
# weight: This represents the multinomial probability (with no replacement)
|
# weight: This represents the multinomial probability (with no replacement)
|
||||||
# used for sampling the transform. If the sum of the weights is not 1,
|
# used for sampling the transform. If the sum of the weights is not 1,
|
||||||
# they will be normalized.
|
# they will be normalized.
|
||||||
# min/max: Lower & upper bound respectively used for sampling the transform's parameter
|
# min_max: Lower & upper bound respectively used for sampling the transform's parameter
|
||||||
# (following uniform distribution) when it's applied.
|
# (following uniform distribution) when it's applied.
|
||||||
enable: false
|
enable: false
|
||||||
# This is the number of transforms (sampled from these below) that will be applied to each frame.
|
# This is the number of transforms (sampled from these below) that will be applied to each frame.
|
||||||
@@ -78,21 +78,16 @@ image_transforms:
|
|||||||
random_order: false
|
random_order: false
|
||||||
brightness:
|
brightness:
|
||||||
weight: 1
|
weight: 1
|
||||||
min: 0.8
|
min_max: [0.8, 1.2]
|
||||||
max: 1.2
|
|
||||||
contrast:
|
contrast:
|
||||||
weight: 1
|
weight: 1
|
||||||
min: 0.8
|
min_max: [0.8, 1.2]
|
||||||
max: 1.2
|
|
||||||
saturation:
|
saturation:
|
||||||
weight: 1
|
weight: 1
|
||||||
min: 0.5
|
min_max: [0.5, 1.5]
|
||||||
max: 1.5
|
|
||||||
hue:
|
hue:
|
||||||
weight: 1
|
weight: 1
|
||||||
min: -0.05
|
min_max: [-0.05, 0.05]
|
||||||
max: 0.05
|
|
||||||
sharpness:
|
sharpness:
|
||||||
weight: 1
|
weight: 1
|
||||||
min: 0.8
|
min_max: [0.8, 1.2]
|
||||||
max: 1.2
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
@@ -8,144 +9,245 @@ from PIL import Image
|
|||||||
from torchvision.transforms import v2
|
from torchvision.transforms import v2
|
||||||
from torchvision.transforms.v2 import functional as F # noqa: N812
|
from torchvision.transforms.v2 import functional as F # noqa: N812
|
||||||
|
|
||||||
from lerobot.common.datasets.transforms import RandomSubsetApply, RangeRandomSharpness, make_image_transforms
|
from lerobot.common.datasets.transforms import RandomSubsetApply, RangeRandomSharpness, get_image_transforms
|
||||||
from lerobot.common.utils.utils import seeded_context
|
from lerobot.common.utils.utils import seeded_context
|
||||||
|
|
||||||
|
|
||||||
class TestRandomSubsetApply:
|
# test_make_image_transforms
|
||||||
@pytest.fixture(autouse=True)
|
# -
|
||||||
def setup(self):
|
|
||||||
self.jitters = [
|
|
||||||
v2.ColorJitter(brightness=0.5),
|
|
||||||
v2.ColorJitter(contrast=0.5),
|
|
||||||
v2.ColorJitter(saturation=0.5),
|
|
||||||
]
|
|
||||||
self.flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
|
|
||||||
self.img = torch.rand(3, 224, 224)
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("p", [[0, 1], [1, 0]])
|
# test backward compatibility torchvision
|
||||||
def test_random_choice(self, p):
|
# - save artifacts
|
||||||
random_choice = RandomSubsetApply(self.flips, p=p, n_subset=1, random_order=False)
|
|
||||||
output = random_choice(self.img)
|
|
||||||
|
|
||||||
p_horz, _ = p
|
# test backward compatibility default yaml (enable false, enable true)
|
||||||
if p_horz:
|
# - save artifacts
|
||||||
torch.testing.assert_close(output, F.horizontal_flip(self.img))
|
|
||||||
else:
|
|
||||||
torch.testing.assert_close(output, F.vertical_flip(self.img))
|
|
||||||
|
|
||||||
def test_transform_all(self):
|
|
||||||
transform = RandomSubsetApply(self.jitters)
|
|
||||||
output = transform(self.img)
|
|
||||||
assert output.shape == self.img.shape
|
|
||||||
|
|
||||||
def test_transform_subset(self):
|
|
||||||
transform = RandomSubsetApply(self.jitters, n_subset=2)
|
|
||||||
output = transform(self.img)
|
|
||||||
assert output.shape == self.img.shape
|
|
||||||
|
|
||||||
def test_random_order(self):
|
|
||||||
random_order = RandomSubsetApply(self.flips, p=[0.5, 0.5], n_subset=2, random_order=True)
|
|
||||||
# We can't really check whether the transforms are actually applied in random order. However,
|
|
||||||
# horizontal and vertical flip are commutative. Meaning, even under the assumption that the transform
|
|
||||||
# applies them in random order, we can use a fixed order to compute the expected value.
|
|
||||||
actual = random_order(self.img)
|
|
||||||
expected = v2.Compose(self.flips)(self.img)
|
|
||||||
torch.testing.assert_close(actual, expected)
|
|
||||||
|
|
||||||
def test_probability_length_mismatch(self):
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
RandomSubsetApply(self.jitters, p=[0.5, 0.5])
|
|
||||||
|
|
||||||
def test_invalid_n_subset(self):
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
RandomSubsetApply(self.jitters, n_subset=5)
|
|
||||||
|
|
||||||
|
|
||||||
class TestRangeRandomSharpness:
|
def test_get_image_transforms_no_transform():
|
||||||
@pytest.fixture(autouse=True)
|
get_image_transforms()
|
||||||
def setup(self):
|
get_image_transforms(sharpness_weight=0.0)
|
||||||
self.img = torch.rand(3, 224, 224)
|
get_image_transforms(max_num_transforms=0)
|
||||||
|
|
||||||
def test_valid_range(self):
|
|
||||||
transform = RangeRandomSharpness(0.1, 2.0)
|
|
||||||
output = transform(self.img)
|
|
||||||
assert output.shape == self.img.shape
|
|
||||||
|
|
||||||
def test_invalid_range_min_negative(self):
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
RangeRandomSharpness(-0.1, 2.0)
|
|
||||||
|
|
||||||
def test_invalid_range_max_smaller(self):
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
RangeRandomSharpness(2.0, 0.1)
|
|
||||||
|
|
||||||
|
|
||||||
class TestMakeImageTransforms:
|
@pytest.fixture
|
||||||
@pytest.fixture(autouse=True)
|
def img():
|
||||||
def setup(self):
|
# dataset = LeRobotDataset("lerobot/pusht")
|
||||||
"""Seed should be the same as the one that was used to generate artifacts"""
|
# item = dataset[0]
|
||||||
self.config = {
|
# return item["observation.image"]
|
||||||
"enable": True,
|
path = "tests/data/save_image_transforms/original_frame.png"
|
||||||
"max_num_transforms": 1,
|
img_chw = torch.from_numpy(np.array(Image.open(path).convert("RGB"))).permute(2, 0, 1)
|
||||||
"random_order": False,
|
return img_chw
|
||||||
"brightness": {"weight": 0, "min": 2.0, "max": 2.0},
|
|
||||||
"contrast": {
|
|
||||||
"weight": 0,
|
|
||||||
"min": 2.0,
|
|
||||||
"max": 2.0,
|
|
||||||
},
|
|
||||||
"saturation": {
|
|
||||||
"weight": 0,
|
|
||||||
"min": 2.0,
|
|
||||||
"max": 2.0,
|
|
||||||
},
|
|
||||||
"hue": {
|
|
||||||
"weight": 0,
|
|
||||||
"min": 0.5,
|
|
||||||
"max": 0.5,
|
|
||||||
},
|
|
||||||
"sharpness": {
|
|
||||||
"weight": 0,
|
|
||||||
"min": 2.0,
|
|
||||||
"max": 2.0,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
self.path = Path("tests/data/save_image_transforms")
|
|
||||||
self.original_frame = self.load_png_to_tensor(self.path / "original_frame.png")
|
|
||||||
self.transforms = {
|
|
||||||
"brightness": v2.ColorJitter(brightness=(2.0, 2.0)),
|
|
||||||
"contrast": v2.ColorJitter(contrast=(2.0, 2.0)),
|
|
||||||
"saturation": v2.ColorJitter(saturation=(2.0, 2.0)),
|
|
||||||
"hue": v2.ColorJitter(hue=(0.5, 0.5)),
|
|
||||||
"sharpness": RangeRandomSharpness(2.0, 2.0),
|
|
||||||
}
|
|
||||||
|
|
||||||
@staticmethod
|
def test_get_image_transforms_brightness(img):
|
||||||
def load_png_to_tensor(path: Path):
|
brightness_min_max = (0.5, 0.5)
|
||||||
return torch.from_numpy(np.array(Image.open(path).convert("RGB"))).permute(2, 0, 1)
|
tf_actual = get_image_transforms(brightness_weight=1., brightness_min_max=brightness_min_max)
|
||||||
|
tf_expected = v2.ColorJitter(brightness=brightness_min_max)
|
||||||
|
torch.testing.assert_close(tf_actual(img), tf_expected(img))
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
def test_get_image_transforms_contrast(img):
|
||||||
"transform_key, seed",
|
contrast_min_max = (0.5, 0.5)
|
||||||
[
|
tf_actual = get_image_transforms(contrast_weight=1., contrast_min_max=contrast_min_max)
|
||||||
("brightness", 1336),
|
tf_expected = v2.ColorJitter(contrast=contrast_min_max)
|
||||||
("contrast", 1336),
|
torch.testing.assert_close(tf_actual(img), tf_expected(img))
|
||||||
("saturation", 1336),
|
|
||||||
("hue", 1336),
|
def test_get_image_transforms_saturation(img):
|
||||||
("sharpness", 1336),
|
saturation_min_max = (0.5, 0.5)
|
||||||
],
|
tf_actual = get_image_transforms(saturation_weight=1., saturation_min_max=saturation_min_max)
|
||||||
|
tf_expected = v2.ColorJitter(saturation=saturation_min_max)
|
||||||
|
torch.testing.assert_close(tf_actual(img), tf_expected(img))
|
||||||
|
|
||||||
|
def test_get_image_transforms_hue(img):
|
||||||
|
hue_min_max = (0.5, 0.5)
|
||||||
|
tf_actual = get_image_transforms(hue_weight=1., hue_min_max=hue_min_max)
|
||||||
|
tf_expected = v2.ColorJitter(hue=hue_min_max)
|
||||||
|
torch.testing.assert_close(tf_actual(img), tf_expected(img))
|
||||||
|
|
||||||
|
def test_get_image_transforms_sharpness(img):
|
||||||
|
sharpness_min_max = (0.5, 0.5)
|
||||||
|
tf_actual = get_image_transforms(sharpness_weight=1., sharpness_min_max=sharpness_min_max)
|
||||||
|
tf_expected = RangeRandomSharpness(**sharpness_min_max)
|
||||||
|
torch.testing.assert_close(tf_actual(img), tf_expected(img))
|
||||||
|
|
||||||
|
def test_get_image_transforms_max_num_transforms(img):
|
||||||
|
tf_actual = get_image_transforms(
|
||||||
|
saturation_min_max=(0.5, 0.5),
|
||||||
|
constrast_min_max=(0.5, 0.5),
|
||||||
|
saturation_min_max=(0.5, 0.5),
|
||||||
|
hue_min_max=(0.5, 0.5),
|
||||||
|
sharpness_min_max=(0.5, 0.5),
|
||||||
|
random_order=False,
|
||||||
)
|
)
|
||||||
def test_single_transform(self, transform_key, seed):
|
tf_expected = v2.Compose([
|
||||||
config = self.config
|
v2.ColorJitter(brightness=(0.5, 0.5)),
|
||||||
config[transform_key]["weight"] = 1
|
v2.ColorJitter(contrast=(0.5, 0.5)),
|
||||||
cfg = OmegaConf.create(config)
|
v2.ColorJitter(saturation=(0.5, 0.5)),
|
||||||
|
v2.ColorJitter(hue=(0.5, 0.5)),
|
||||||
|
RangeRandomSharpness(sharpness=(0.5, 0.5)),
|
||||||
|
])
|
||||||
|
torch.testing.assert_close(tf_actual(img), tf_expected(img))
|
||||||
|
|
||||||
actual_t = make_image_transforms(cfg, to_dtype=torch.uint8)
|
|
||||||
with seeded_context(1336):
|
|
||||||
actual = actual_t(self.original_frame)
|
|
||||||
|
|
||||||
expected_t = self.transforms[transform_key]
|
def test_get_image_transforms_random_order(img):
|
||||||
with seeded_context(1336):
|
out_imgs = []
|
||||||
expected = expected_t(self.original_frame)
|
with seeded_context(1337):
|
||||||
|
for _ in range(20):
|
||||||
|
tf = get_image_transforms(
|
||||||
|
saturation_min_max=(0.5, 0.5),
|
||||||
|
constrast_min_max=(0.5, 0.5),
|
||||||
|
saturation_min_max=(0.5, 0.5),
|
||||||
|
hue_min_max=(0.5, 0.5),
|
||||||
|
sharpness_min_max=(0.5, 0.5),
|
||||||
|
random_order=False,
|
||||||
|
)
|
||||||
|
out_imgs.append(tf(img))
|
||||||
|
|
||||||
|
for i in range(1,10):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
torch.testing.assert_close(out_imgs[0], out_imgs[i])
|
||||||
|
|
||||||
torch.testing.assert_close(actual, expected)
|
|
||||||
|
|
||||||
|
def test_backward_compatibility_torchvision():
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_backward_compatibility_default_yaml():
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# class TestRandomSubsetApply:
|
||||||
|
# @pytest.fixture(autouse=True)
|
||||||
|
# def setup(self):
|
||||||
|
# self.jitters = [
|
||||||
|
# v2.ColorJitter(brightness=0.5),
|
||||||
|
# v2.ColorJitter(contrast=0.5),
|
||||||
|
# v2.ColorJitter(saturation=0.5),
|
||||||
|
# ]
|
||||||
|
# self.flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
|
||||||
|
# self.img = torch.rand(3, 224, 224)
|
||||||
|
|
||||||
|
# @pytest.mark.parametrize("p", [[0, 1], [1, 0]])
|
||||||
|
# def test_random_choice(self, p):
|
||||||
|
# random_choice = RandomSubsetApply(self.flips, p=p, n_subset=1, random_order=False)
|
||||||
|
# output = random_choice(self.img)
|
||||||
|
|
||||||
|
# p_horz, _ = p
|
||||||
|
# if p_horz:
|
||||||
|
# torch.testing.assert_close(output, F.horizontal_flip(self.img))
|
||||||
|
# else:
|
||||||
|
# torch.testing.assert_close(output, F.vertical_flip(self.img))
|
||||||
|
|
||||||
|
# def test_transform_all(self):
|
||||||
|
# transform = RandomSubsetApply(self.jitters)
|
||||||
|
# output = transform(self.img)
|
||||||
|
# assert output.shape == self.img.shape
|
||||||
|
|
||||||
|
# def test_transform_subset(self):
|
||||||
|
# transform = RandomSubsetApply(self.jitters, n_subset=2)
|
||||||
|
# output = transform(self.img)
|
||||||
|
# assert output.shape == self.img.shape
|
||||||
|
|
||||||
|
# def test_random_order(self):
|
||||||
|
# random_order = RandomSubsetApply(self.flips, p=[0.5, 0.5], n_subset=2, random_order=True)
|
||||||
|
# # We can't really check whether the transforms are actually applied in random order. However,
|
||||||
|
# # horizontal and vertical flip are commutative. Meaning, even under the assumption that the transform
|
||||||
|
# # applies them in random order, we can use a fixed order to compute the expected value.
|
||||||
|
# actual = random_order(self.img)
|
||||||
|
# expected = v2.Compose(self.flips)(self.img)
|
||||||
|
# torch.testing.assert_close(actual, expected)
|
||||||
|
|
||||||
|
# def test_probability_length_mismatch(self):
|
||||||
|
# with pytest.raises(ValueError):
|
||||||
|
# RandomSubsetApply(self.jitters, p=[0.5, 0.5])
|
||||||
|
|
||||||
|
# def test_invalid_n_subset(self):
|
||||||
|
# with pytest.raises(ValueError):
|
||||||
|
# RandomSubsetApply(self.jitters, n_subset=5)
|
||||||
|
|
||||||
|
|
||||||
|
# class TestRangeRandomSharpness:
|
||||||
|
# @pytest.fixture(autouse=True)
|
||||||
|
# def setup(self):
|
||||||
|
# self.img = torch.rand(3, 224, 224)
|
||||||
|
|
||||||
|
# def test_valid_range(self):
|
||||||
|
# transform = RangeRandomSharpness(0.1, 2.0)
|
||||||
|
# output = transform(self.img)
|
||||||
|
# assert output.shape == self.img.shape
|
||||||
|
|
||||||
|
# def test_invalid_range_min_negative(self):
|
||||||
|
# with pytest.raises(ValueError):
|
||||||
|
# RangeRandomSharpness(-0.1, 2.0)
|
||||||
|
|
||||||
|
# def test_invalid_range_max_smaller(self):
|
||||||
|
# with pytest.raises(ValueError):
|
||||||
|
# RangeRandomSharpness(2.0, 0.1)
|
||||||
|
|
||||||
|
|
||||||
|
# class TestMakeImageTransforms:
|
||||||
|
# @pytest.fixture(autouse=True)
|
||||||
|
# def setup(self):
|
||||||
|
# """Seed should be the same as the one that was used to generate artifacts"""
|
||||||
|
# self.config = {
|
||||||
|
# "enable": True,
|
||||||
|
# "max_num_transforms": 1,
|
||||||
|
# "random_order": False,
|
||||||
|
# "brightness": {"weight": 0, "min": 2.0, "max": 2.0},
|
||||||
|
# "contrast": {
|
||||||
|
# "weight": 0,
|
||||||
|
# "min": 2.0,
|
||||||
|
# "max": 2.0,
|
||||||
|
# },
|
||||||
|
# "saturation": {
|
||||||
|
# "weight": 0,
|
||||||
|
# "min": 2.0,
|
||||||
|
# "max": 2.0,
|
||||||
|
# },
|
||||||
|
# "hue": {
|
||||||
|
# "weight": 0,
|
||||||
|
# "min": 0.5,
|
||||||
|
# "max": 0.5,
|
||||||
|
# },
|
||||||
|
# "sharpness": {
|
||||||
|
# "weight": 0,
|
||||||
|
# "min": 2.0,
|
||||||
|
# "max": 2.0,
|
||||||
|
# },
|
||||||
|
# }
|
||||||
|
# self.path = Path("tests/data/save_image_transforms")
|
||||||
|
# self.original_frame = self.load_png_to_tensor(self.path / "original_frame.png")
|
||||||
|
# self.transforms = {
|
||||||
|
# "brightness": v2.ColorJitter(brightness=(2.0, 2.0)),
|
||||||
|
# "contrast": v2.ColorJitter(contrast=(2.0, 2.0)),
|
||||||
|
# "saturation": v2.ColorJitter(saturation=(2.0, 2.0)),
|
||||||
|
# "hue": v2.ColorJitter(hue=(0.5, 0.5)),
|
||||||
|
# "sharpness": RangeRandomSharpness(2.0, 2.0),
|
||||||
|
# }
|
||||||
|
|
||||||
|
# @staticmethod
|
||||||
|
# def load_png_to_tensor(path: Path):
|
||||||
|
# return torch.from_numpy(np.array(Image.open(path).convert("RGB"))).permute(2, 0, 1)
|
||||||
|
|
||||||
|
# @pytest.mark.parametrize(
|
||||||
|
# "transform_key, seed",
|
||||||
|
# [
|
||||||
|
# ("brightness", 1336),
|
||||||
|
# ("contrast", 1336),
|
||||||
|
# ("saturation", 1336),
|
||||||
|
# ("hue", 1336),
|
||||||
|
# ("sharpness", 1336),
|
||||||
|
# ],
|
||||||
|
# )
|
||||||
|
# def test_single_transform(self, transform_key, seed):
|
||||||
|
# config = self.config
|
||||||
|
# config[transform_key]["weight"] = 1
|
||||||
|
# cfg = OmegaConf.create(config)
|
||||||
|
|
||||||
|
# actual_t = make_image_transforms(cfg, to_dtype=torch.uint8)
|
||||||
|
# with seeded_context(1336):
|
||||||
|
# actual = actual_t(self.original_frame)
|
||||||
|
|
||||||
|
# expected_t = self.transforms[transform_key]
|
||||||
|
# with seeded_context(1336):
|
||||||
|
# expected = expected_t(self.original_frame)
|
||||||
|
|
||||||
|
# torch.testing.assert_close(actual, expected)
|
||||||
|
|||||||
Reference in New Issue
Block a user