Compare commits

...

40 Commits

Author SHA1 Message Date
Remi Cadene
ca81b1d6f4 WIP 2024-06-08 12:10:04 +02:00
Simon Alibert
e52942a200 fix TestMakeImageTransforms 2024-06-07 17:23:54 +02:00
Simon Alibert
b60810a8b6 Update config doc 2024-06-07 10:55:18 +00:00
Simon Alibert
faacb36271 transform -> image_transforms 2024-06-06 16:53:37 +00:00
Simon Alibert
c45dd8f848 rename to image_transforms 2024-06-06 16:50:22 +00:00
Simon Alibert
a86f387554 WIP 2024-06-06 15:24:31 +00:00
Marina Barannikov
bdc0ebd36a Updated default transform parameters 2024-06-06 13:50:48 +00:00
Marina Barannikov
19f4a6568d Updated default image_transform parameters 2024-06-06 13:37:58 +00:00
Marina Barannikov
9552a4f010 Added clarification comments 2024-06-06 09:24:58 +00:00
Marina Barannikov
d657139828 Updated comments 2024-06-06 09:23:39 +00:00
Simon Alibert
b1714803a3 Disable image_transform by default 2024-06-06 08:39:52 +00:00
Simon Alibert
5d55b19cbd Fix tests 2024-06-05 16:47:52 +00:00
Simon Alibert
641d349df4 Add save_image_transforms.py & artifacts 2024-06-05 16:30:47 +00:00
Simon Alibert
e444b0d529 Add first tests 2024-06-05 16:29:54 +00:00
Marina Barannikov
8237ed9aa4 Updated visualize script 2024-06-05 16:01:37 +00:00
Simon Alibert
82e32f1fcd Fix RandomSubsetApply weighted sampling 2024-06-05 14:19:37 +00:00
Marina Barannikov
644e77e413 Renamed scripts 2024-06-05 13:35:41 +00:00
Marina Barannikov
1b1bbb1632 Minor formatting 2024-06-05 13:31:40 +00:00
Marina Barannikov
0fb3dd745b Implented visualize_image_transforms script 2024-06-05 13:30:32 +00:00
Marina Barannikov
4dbc1adb0d Updated show_transform to match config 2024-06-05 12:32:53 +00:00
Simon Alibert
ceb95592af Remove prints 2024-06-05 12:21:00 +00:00
Simon Alibert
6509c3f6d4 Implement RandomSubsetApply features 2024-06-05 12:15:36 +00:00
Marina Barannikov
8b134725d5 Merge branch 'huggingface:main' into 2024_05_30_add_data_augmentation 2024-06-05 13:56:47 +02:00
Marina Barannikov
a544949ebe Added example of torchvision image augmentation on LeRobotDataset 2024-06-05 10:53:18 +00:00
Simon Alibert
fdf56e7a62 Redesign config 2024-06-05 09:49:31 +00:00
Simon Alibert
443b06b412 refactor show_image_transforms 2024-06-05 09:34:39 +00:00
Marina Barannikov
22bd1f0669 Updated formatting 2024-06-04 12:06:36 +00:00
Marina Barannikov
31e3c82386 Merge remote-tracking branch 'refs/remotes/origin/2024_05_30_add_data_augmentation' into 2024_05_30_add_data_augmentation 2024-06-04 12:00:46 +00:00
Marina Barannikov
5eea2542d9 Added visualisations for image augmentation 2024-06-04 11:57:45 +00:00
Marina Barannikov
42f9cc9c2a Updated transforms arguments 2024-06-04 11:14:54 +00:00
Marina Barannikov
66629a956d Updated config to match transforms 2024-06-04 11:09:23 +00:00
Simon Alibert
7be2c35c0a Merge branch 'huggingface:main' into 2024_05_30_add_data_augmentation 2024-06-04 12:28:06 +02:00
Marina Barannikov
14291171cc Updated default.yaml 2024-06-03 17:21:16 +00:00
Marina Barannikov
cc4b3bd8e7 Updated default.yaml 2024-06-03 17:18:33 +00:00
Simon Alibert
602ea9844b Add RandomSubsetApply 2024-06-03 17:15:37 +00:00
Marina Barannikov
9f8415fa83 Added clarification comments 2024-06-03 14:18:08 +00:00
Marina Barannikov
212a5ab29b Updated implementation on MultiLeRobotDataset 2024-05-31 16:28:08 +00:00
Marina Barannikov
c4870e5892 Added data augmentation feature to MultiLeRobotDataset 2024-05-31 15:42:31 +00:00
Marina Barannikov
20a3715469 Merge remote-tracking branch 'origin/main' into 2024_05_30_add_data_augmentation 2024-05-31 14:50:31 +00:00
marina.barannikov@huggingface.co
65e46a49e1 Implemented data augmentation with LeRobot class 2024-05-31 14:16:38 +00:00
9 changed files with 622 additions and 12 deletions

View File

@@ -0,0 +1,51 @@
"""
This script demonstrates how to apply torchvision image augmentations to a LeRobotDataset and how to save examples of the transformed images.
The transforms are passed to the dataset as an argument at creation time and are applied to the observation images before they are returned.
"""
from pathlib import Path
from torchvision.transforms import ToPILImage, v2
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
to_pil = ToPILImage()
# Create a directory to store output images
output_dir = Path("outputs/image_transforms")
output_dir.mkdir(parents=True, exist_ok=True)
repo_id = "lerobot/aloha_static_tape"
# Create a LeRobotDataset with no transformations
dataset = LeRobotDataset(repo_id, image_transforms=None)
# Get the index of the first observation in the first episode
first_idx = dataset.episode_data_index["from"][0].item()
# Get the frame corresponding to the first camera
frame = dataset[first_idx][dataset.camera_keys[0]]
# Save the original frame
to_pil(frame).save(output_dir / "original_frame.png", quality=100)
print(f"Original frame saved to {output_dir / 'original_frame.png'}.")
# Define the transformations
transforms = v2.Compose(
[
v2.ColorJitter(brightness=(0.5, 1.5)),
v2.ColorJitter(contrast=(0.5, 1.5)),
v2.RandomAdjustSharpness(sharpness_factor=2, p=1),
]
)
# Create another LeRobotDataset with the defined transformations
transformed_dataset = LeRobotDataset(repo_id, image_transforms=transforms)
# Get a frame from the transformed dataset
transformed_frame = transformed_dataset[first_idx][transformed_dataset.camera_keys[0]]
# Save the transformed frame
to_pil(transformed_frame).save(output_dir / "transformed_frame.png", quality=100)
print(f"Transformed frame saved to {output_dir / 'transformed_frame.png'}.")

View File

@@ -19,6 +19,7 @@ import torch
from omegaconf import ListConfig, OmegaConf
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
from lerobot.common.datasets.transforms import get_image_transforms
def resolve_delta_timestamps(cfg):
@@ -71,17 +72,36 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
resolve_delta_timestamps(cfg)
# TODO(rcadene): add data augmentations
image_transforms = None
if cfg.image_transforms.enable:
image_transforms = get_image_transforms(
brightness_weight=cfg.image_transforms.brightness.weight,
brightness_min_max=cfg.image_transforms.brightness.min_max,
contrast_weight=cfg.image_transforms.contrast.weight,
contrast_min_max=cfg.image_transforms.contrast.min_max,
saturation_weight=cfg.image_transforms.saturation.weight,
saturation_min_max=cfg.image_transforms.saturation.min_max,
hue_weight=cfg.image_transforms.hue.weight,
hue_min_max=cfg.image_transforms.hue.min_max,
sharpness_weight=cfg.image_transforms.sharpness.weight,
sharpness_min_max=cfg.image_transforms.sharpness.min_max,
max_num_transforms=cfg.image_transforms.max_num_transforms,
random_order=cfg.image_transforms.random_order,
)
if isinstance(cfg.dataset_repo_id, str):
dataset = LeRobotDataset(
cfg.dataset_repo_id,
split=split,
delta_timestamps=cfg.training.get("delta_timestamps"),
image_transforms=image_transforms,
)
else:
dataset = MultiLeRobotDataset(
cfg.dataset_repo_id, split=split, delta_timestamps=cfg.training.get("delta_timestamps")
cfg.dataset_repo_id,
split=split,
delta_timestamps=cfg.training.get("delta_timestamps"),
image_transforms=image_transforms,
)
if cfg.get("override_dataset_stats"):

View File

@@ -46,7 +46,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
version: str | None = CODEBASE_VERSION,
root: Path | None = DATA_DIR,
split: str = "train",
transform: Callable | None = None,
image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None,
):
super().__init__()
@@ -54,7 +54,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.version = version
self.root = root
self.split = split
self.transform = transform
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
# load data from hub or locally when root is provided
# TODO(rcadene, aliberts): implement faster transfer
@@ -151,8 +151,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.tolerance_s,
)
if self.transform is not None:
item = self.transform(item)
if self.image_transforms is not None:
for cam in self.camera_keys:
item[cam] = self.image_transforms(item[cam])
return item
@@ -168,7 +169,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
f" Recorded Frames per Second: {self.fps},\n"
f" Camera Keys: {self.camera_keys},\n"
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
f" Transformations: {self.transform},\n"
f" Transformations: {self.image_transforms},\n"
f")"
)
@@ -202,7 +203,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.version = version
obj.root = root
obj.split = split
obj.transform = transform
obj.image_transforms = transform
obj.delta_timestamps = delta_timestamps
obj.hf_dataset = hf_dataset
obj.episode_data_index = episode_data_index
@@ -225,7 +226,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
version: str | None = CODEBASE_VERSION,
root: Path | None = DATA_DIR,
split: str = "train",
transform: Callable | None = None,
image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None,
):
super().__init__()
@@ -239,7 +240,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
root=root,
split=split,
delta_timestamps=delta_timestamps,
transform=transform,
image_transforms=image_transforms,
)
for repo_id in repo_ids
]
@@ -274,7 +275,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
self.version = version
self.root = root
self.split = split
self.transform = transform
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
self.stats = aggregate_stats(self._datasets)
@@ -380,6 +381,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
for data_key in self.disabled_data_keys:
if data_key in item:
del item[data_key]
return item
def __repr__(self):
@@ -394,6 +396,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
f" Recorded Frames per Second: {self.fps},\n"
f" Camera Keys: {self.camera_keys},\n"
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
f" Transformations: {self.transform},\n"
f" Transformations: {self.image_transforms},\n"
f")"
)

View File

@@ -0,0 +1,157 @@
from typing import Any, Callable, Dict, Sequence
import torch
from torchvision.transforms import v2
from torchvision.transforms.v2 import Transform
from torchvision.transforms.v2 import functional as F # noqa: N812
class RandomSubsetApply(Transform):
"""
Apply a random subset of N transformations from a list of transformations.
Args:
transforms (sequence or torch.nn.Module): list of transformations
p (list of floats or None, optional): probability of each transform being picked.
If ``p`` doesn't sum to 1, it is automatically normalized. If ``None``
(default), all transforms have the same probability.
n_subset (int or None): number of transformations to apply. If ``None``,
all transforms are applied.
random_order (bool): apply transformations in a random order
"""
def __init__(
self,
transforms: Sequence[Callable],
p: list[float] | None = None,
n_subset: int | None = None,
random_order: bool = False,
) -> None:
super().__init__()
if not isinstance(transforms, Sequence):
raise TypeError("Argument transforms should be a sequence of callables")
if p is None:
p = [1] * len(transforms)
elif len(p) != len(transforms):
raise ValueError(
f"Length of p doesn't match the number of transforms: {len(p)} != {len(transforms)}"
)
if n_subset is None:
n_subset = len(transforms)
elif not isinstance(n_subset, int):
raise TypeError("n_subset should be an int or None")
elif not (0 <= n_subset <= len(transforms)):
raise ValueError(f"n_subset should be in the interval [0, {len(transforms)}]")
self.transforms = transforms
total = sum(p)
self.p = [prob / total for prob in p]
self.n_subset = n_subset
self.random_order = random_order
def forward(self, *inputs: Any) -> Any:
needs_unpacking = len(inputs) > 1
selected_indices = torch.multinomial(torch.tensor(self.p), self.n_subset)
if not self.random_order:
selected_indices = selected_indices.sort().values
selected_transforms = [self.transforms[i] for i in selected_indices]
for transform in selected_transforms:
outputs = transform(*inputs)
inputs = outputs if needs_unpacking else (outputs,)
return outputs
def extra_repr(self) -> str:
return (
f"transforms={self.transforms}, "
f"p={self.p}, "
f"n_subset={self.n_subset}, "
f"random_order={self.random_order}"
)
class RangeRandomSharpness(Transform):
"""Similar to v2.RandomAdjustSharpness but with p=1 and a sharpness_factor sampled randomly
each time in [range_min, range_max].
If the input is a :class:`torch.Tensor`,
it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
"""
def __init__(self, range_min: float, range_max: float) -> None:
super().__init__()
self.range_min, self.range_max = self._check_input(range_min, range_max)
def _check_input(self, range_min, range_max):
if range_min < 0:
raise ValueError("range_min must be non negative.")
if range_min > range_max:
raise ValueError("range_max must greater or equal to range_min")
return range_min, range_max
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
sharpness_factor = self.range_min + (self.range_max - self.range_min) * torch.rand(1).item()
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
def get_image_transforms(
brightness_weight: float = 1.0,
brightness_min_max: tuple[float, float] | None = None,
contrast_weight: float = 1.0,
contrast_min_max: tuple[float, float] | None = None,
saturation_weight: float = 1.0,
saturation_min_max: tuple[float, float] | None = None,
hue_weight: float = 1.0,
hue_min_max: tuple[float, float] | None = None,
sharpness_weight: float = 1.0,
sharpness_min_max: tuple[float, float] | None = None,
max_num_transforms: int | None = None,
random_order: bool = False,
):
def check_value_error(name, weight, min_max):
if min_max is not None:
if len(min_max) != 2:
raise ValueError(f"`{name}_min_max` is expected to be a tuple of 2 dimensions, but {min_max} provided.")
if weight < 0.:
raise ValueError(f"`{name}_weight` is expected to be 0 or positive, but is negative ({weight}).")
check_value_error("brightness", brightness_weight, brightness_min_max)
check_value_error("contrast", contrast_weight, contrast_min_max)
check_value_error("saturation", saturation_weight, saturation_min_max)
check_value_error("hue", hue_weight, hue_min_max)
check_value_error("sharpness", sharpness_weight, sharpness_min_max)
weights = []
transforms = []
if brightness_min_max is not None:
weights.append(brightness_weight)
transforms.append(v2.ColorJitter(brightness=brightness_min_max))
if contrast_min_max is not None:
weights.append(contrast_weight)
transforms.append(v2.ColorJitter(contrast=contrast_min_max))
if saturation_min_max is not None:
weights.append(saturation_weight)
transforms.append(v2.ColorJitter(saturation=saturation_min_max))
if hue_min_max is not None:
weights.append(hue_weight)
transforms.append(v2.ColorJitter(hue=hue_min_max))
if sharpness_min_max is not None:
weights.append(sharpness_weight)
transforms.append(RangeRandomSharpness(*sharpness_min_max))  # unpack (range_min, range_max)
if max_num_transforms is None:
n_subset = len(transforms)
else:
n_subset = min(len(transforms), max_num_transforms)
final_transforms = RandomSubsetApply(
transforms, p=weights, n_subset=n_subset, random_order=random_order
)
# TODO(rcadene, aliberts): add v2.ToDtype float16?
return final_transforms
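
For readers skimming the diff: get_image_transforms builds one jitter transform per parameter whose min_max is set and wraps them in RandomSubsetApply, which draws a weighted subset of them on every call. A minimal usage sketch (not part of the PR; the variable names and values are illustrative, and it assumes a float image tensor in [0, 1] with shape [3, H, W]):

import torch
from torchvision.transforms import v2

from lerobot.common.datasets.transforms import (
    RandomSubsetApply,
    RangeRandomSharpness,
    get_image_transforms,
)

img = torch.rand(3, 224, 224)  # stand-in for an observation image

# Compose the building blocks by hand: per call, 2 of the 3 transforms are
# drawn without replacement, with probabilities proportional to p.
tf = RandomSubsetApply(
    transforms=[
        v2.ColorJitter(brightness=(0.8, 1.2)),
        v2.ColorJitter(contrast=(0.8, 1.2)),
        RangeRandomSharpness(0.8, 1.2),
    ],
    p=[2.0, 1.0, 1.0],
    n_subset=2,
    random_order=False,
)
out = tf(img)

# Or let get_image_transforms assemble a similar pipeline: a transform is
# enabled by passing its min_max, and max_num_transforms caps the subset size.
tf = get_image_transforms(
    brightness_min_max=(0.8, 1.2),
    contrast_min_max=(0.8, 1.2),
    sharpness_min_max=(0.8, 1.2),
    max_num_transforms=2,
)
out = tf(img)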

View File

@@ -57,3 +57,37 @@ wandb:
disable_artifact: false
project: lerobot
notes: ""
image_transforms:
# These transforms are all using standard torchvision.transforms.v2
# You can find out how these transformations affect images here:
# https://pytorch.org/vision/0.18/auto_examples/transforms/plot_transforms_illustrations.html
# We use a custom RandomSubsetApply container to sample them.
# For each transform, the following parameters are available:
# weight: This represents the multinomial probability (with no replacement)
# used for sampling the transform. If the sum of the weights is not 1,
# they will be normalized.
# min_max: Lower & upper bound respectively used for sampling the transform's parameter
# (following uniform distribution) when it's applied.
enable: false
# Maximum number of transforms (sampled from those below) that will be applied to each frame.
# It's an integer in the interval [0, number of available transforms].
max_num_transforms: 3
# By default, transforms are applied in Torchvision's suggested order (shown below).
# Set this to True to apply them in a random order.
random_order: false
brightness:
weight: 1
min_max: [0.8, 1.2]
contrast:
weight: 1
min_max: [0.8, 1.2]
saturation:
weight: 1
min_max: [0.5, 1.5]
hue:
weight: 1
min_max: [-0.05, 0.05]
sharpness:
weight: 1
min_max: [0.8, 1.2]
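
For reference, with enable: true these defaults are consumed by make_dataset (see the dataset factory diff above) and amount to roughly the following call; a sketch, not part of the PR, with the values copied from the YAML block above:

from lerobot.common.datasets.transforms import get_image_transforms

# Roughly what the defaults above produce once enabled: five candidate
# transforms with equal weights, of which up to max_num_transforms=3 are
# applied per frame, each sampling its parameter uniformly within min_max.
default_image_transforms = get_image_transforms(
    brightness_weight=1.0,
    brightness_min_max=(0.8, 1.2),
    contrast_weight=1.0,
    contrast_min_max=(0.8, 1.2),
    saturation_weight=1.0,
    saturation_min_max=(0.5, 1.5),
    hue_weight=1.0,
    hue_min_max=(-0.05, 0.05),
    sharpness_weight=1.0,
    sharpness_min_max=(0.8, 1.2),
    max_num_transforms=3,
    random_order=False,
)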

View File

@@ -0,0 +1,46 @@
from pathlib import Path
import hydra
from torchvision.transforms import ToPILImage
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.transforms import make_image_transforms
to_pil = ToPILImage()
def main(cfg, output_dir=Path("outputs/image_transforms")):
dataset = LeRobotDataset(cfg.dataset_repo_id, image_transforms=None)
output_dir = Path(output_dir) / Path(cfg.dataset_repo_id.split("/")[-1])
output_dir.mkdir(parents=True, exist_ok=True)
# Get first frame of 1st episode
first_idx = dataset.episode_data_index["from"][0].item()
frame = dataset[first_idx][dataset.camera_keys[0]]
to_pil(frame).save(output_dir / "original_frame.png", quality=100)
transforms = ["brightness", "contrast", "saturation", "hue", "sharpness"]
# Apply each single transformation
for transform_name in transforms:
for t in transforms:
if t == transform_name:
cfg.image_transforms[t].weight = 1
else:
cfg.image_transforms[t].weight = 0
transform = make_image_transforms(cfg.image_transforms)
img = transform(frame)
to_pil(img).save(output_dir / f"{transform_name}.png", quality=100)
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
def visualize_transforms_cli(cfg: dict):
main(
cfg,
)
if __name__ == "__main__":
visualize_transforms_cli()

Binary file not shown (new image, 185 KiB).

View File

@@ -0,0 +1,47 @@
from pathlib import Path
import torch
from torchvision.transforms import v2
from safetensors.torch import save_file
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.transforms import RangeRandomSharpness
from lerobot.common.utils.utils import seeded_context
DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml"
ARTIFACT_DIR = "tests/data/save_image_transforms"
SEED = 1336
to_pil = v2.ToPILImage()
def main(repo_id):
dataset = LeRobotDataset(repo_id, image_transforms=None)
output_dir = Path(ARTIFACT_DIR)
output_dir.mkdir(parents=True, exist_ok=True)
# Get first frame of given episode
from_idx = dataset.episode_data_index["from"][0].item()
original_frame = dataset[from_idx][dataset.camera_keys[0]]
to_pil(original_frame).save(output_dir / "original_frame.png", quality=100)
transforms = {
"brightness": v2.ColorJitter(brightness=(0.0, 2.0)),
"contrast": v2.ColorJitter(contrast=(0.0, 2.0)),
"saturation": v2.ColorJitter(saturation=(0.0, 2.0)),
"hue": v2.ColorJitter(hue=(-0.5, 0.5)),
"sharpness": RangeRandomSharpness(0.0, 2.0),
}
# frames = {"original_frame": original_frame}
for name, transform in transforms.items():
with seeded_context(SEED):
# transform = v2.Compose([transform, v2.ToDtype(torch.float32, scale=True)])
transformed_frame = transform(original_frame)
# frames[name] = transform(original_frame)
to_pil(transformed_frame).save(output_dir / f"{SEED}_{name}.png", quality=100)
# save_file(frames, output_dir / f"transformed_frames_{SEED}.safetensors")
if __name__ == "__main__":
repo_id = "lerobot/aloha_mobile_shrimp"
main(repo_id)

tests/test_transforms.py (new file, 253 lines)
View File

@@ -0,0 +1,253 @@
from pathlib import Path
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
import numpy as np
import pytest
import torch
from omegaconf import OmegaConf
from PIL import Image
from torchvision.transforms import v2
from torchvision.transforms.v2 import functional as F # noqa: N812
from lerobot.common.datasets.transforms import RandomSubsetApply, RangeRandomSharpness, get_image_transforms
from lerobot.common.utils.utils import seeded_context
# test_make_image_transforms
# -
# test backward compatibility torchvision
# - save artifacts
# test backward compatibility default yaml (enable false, enable true)
# - save artifacts
def test_get_image_transforms_no_transform():
get_image_transforms()
get_image_transforms(sharpness_weight=0.0)
get_image_transforms(max_num_transforms=0)
@pytest.fixture
def img():
# dataset = LeRobotDataset("lerobot/pusht")
# item = dataset[0]
# return item["observation.image"]
path = "tests/data/save_image_transforms/original_frame.png"
img_chw = torch.from_numpy(np.array(Image.open(path).convert("RGB"))).permute(2, 0, 1)
return img_chw
def test_get_image_transforms_brightness(img):
brightness_min_max = (0.5, 0.5)
tf_actual = get_image_transforms(brightness_weight=1., brightness_min_max=brightness_min_max)
tf_expected = v2.ColorJitter(brightness=brightness_min_max)
torch.testing.assert_close(tf_actual(img), tf_expected(img))
def test_get_image_transforms_contrast(img):
contrast_min_max = (0.5, 0.5)
tf_actual = get_image_transforms(contrast_weight=1., contrast_min_max=contrast_min_max)
tf_expected = v2.ColorJitter(contrast=contrast_min_max)
torch.testing.assert_close(tf_actual(img), tf_expected(img))
def test_get_image_transforms_saturation(img):
saturation_min_max = (0.5, 0.5)
tf_actual = get_image_transforms(saturation_weight=1., saturation_min_max=saturation_min_max)
tf_expected = v2.ColorJitter(saturation=saturation_min_max)
torch.testing.assert_close(tf_actual(img), tf_expected(img))
def test_get_image_transforms_hue(img):
hue_min_max = (0.5, 0.5)
tf_actual = get_image_transforms(hue_weight=1., hue_min_max=hue_min_max)
tf_expected = v2.ColorJitter(hue=hue_min_max)
torch.testing.assert_close(tf_actual(img), tf_expected(img))
def test_get_image_transforms_sharpness(img):
sharpness_min_max = (0.5, 0.5)
tf_actual = get_image_transforms(sharpness_weight=1., sharpness_min_max=sharpness_min_max)
tf_expected = RangeRandomSharpness(*sharpness_min_max)
torch.testing.assert_close(tf_actual(img), tf_expected(img))
def test_get_image_transforms_max_num_transforms(img):
tf_actual = get_image_transforms(
brightness_min_max=(0.5, 0.5),
contrast_min_max=(0.5, 0.5),
saturation_min_max=(0.5, 0.5),
hue_min_max=(0.5, 0.5),
sharpness_min_max=(0.5, 0.5),
random_order=False,
)
tf_expected = v2.Compose([
v2.ColorJitter(brightness=(0.5, 0.5)),
v2.ColorJitter(contrast=(0.5, 0.5)),
v2.ColorJitter(saturation=(0.5, 0.5)),
v2.ColorJitter(hue=(0.5, 0.5)),
RangeRandomSharpness(0.5, 0.5),
])
torch.testing.assert_close(tf_actual(img), tf_expected(img))
def test_get_image_transforms_random_order(img):
out_imgs = []
with seeded_context(1337):
for _ in range(20):
tf = get_image_transforms(
brightness_min_max=(0.5, 0.5),
contrast_min_max=(0.5, 0.5),
saturation_min_max=(0.5, 0.5),
hue_min_max=(0.5, 0.5),
sharpness_min_max=(0.5, 0.5),
random_order=True,
)
out_imgs.append(tf(img))
for i in range(1, 10):
with pytest.raises(ValueError):
torch.testing.assert_close(out_imgs[0], out_imgs[i])
def test_backward_compatibility_torchvision():
pass
def test_backward_compatibility_default_yaml():
pass
# class TestRandomSubsetApply:
# @pytest.fixture(autouse=True)
# def setup(self):
# self.jitters = [
# v2.ColorJitter(brightness=0.5),
# v2.ColorJitter(contrast=0.5),
# v2.ColorJitter(saturation=0.5),
# ]
# self.flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
# self.img = torch.rand(3, 224, 224)
# @pytest.mark.parametrize("p", [[0, 1], [1, 0]])
# def test_random_choice(self, p):
# random_choice = RandomSubsetApply(self.flips, p=p, n_subset=1, random_order=False)
# output = random_choice(self.img)
# p_horz, _ = p
# if p_horz:
# torch.testing.assert_close(output, F.horizontal_flip(self.img))
# else:
# torch.testing.assert_close(output, F.vertical_flip(self.img))
# def test_transform_all(self):
# transform = RandomSubsetApply(self.jitters)
# output = transform(self.img)
# assert output.shape == self.img.shape
# def test_transform_subset(self):
# transform = RandomSubsetApply(self.jitters, n_subset=2)
# output = transform(self.img)
# assert output.shape == self.img.shape
# def test_random_order(self):
# random_order = RandomSubsetApply(self.flips, p=[0.5, 0.5], n_subset=2, random_order=True)
# # We can't really check whether the transforms are actually applied in random order. However,
# # horizontal and vertical flip are commutative. Meaning, even under the assumption that the transform
# # applies them in random order, we can use a fixed order to compute the expected value.
# actual = random_order(self.img)
# expected = v2.Compose(self.flips)(self.img)
# torch.testing.assert_close(actual, expected)
# def test_probability_length_mismatch(self):
# with pytest.raises(ValueError):
# RandomSubsetApply(self.jitters, p=[0.5, 0.5])
# def test_invalid_n_subset(self):
# with pytest.raises(ValueError):
# RandomSubsetApply(self.jitters, n_subset=5)
# class TestRangeRandomSharpness:
# @pytest.fixture(autouse=True)
# def setup(self):
# self.img = torch.rand(3, 224, 224)
# def test_valid_range(self):
# transform = RangeRandomSharpness(0.1, 2.0)
# output = transform(self.img)
# assert output.shape == self.img.shape
# def test_invalid_range_min_negative(self):
# with pytest.raises(ValueError):
# RangeRandomSharpness(-0.1, 2.0)
# def test_invalid_range_max_smaller(self):
# with pytest.raises(ValueError):
# RangeRandomSharpness(2.0, 0.1)
# class TestMakeImageTransforms:
# @pytest.fixture(autouse=True)
# def setup(self):
# """Seed should be the same as the one that was used to generate artifacts"""
# self.config = {
# "enable": True,
# "max_num_transforms": 1,
# "random_order": False,
# "brightness": {"weight": 0, "min": 2.0, "max": 2.0},
# "contrast": {
# "weight": 0,
# "min": 2.0,
# "max": 2.0,
# },
# "saturation": {
# "weight": 0,
# "min": 2.0,
# "max": 2.0,
# },
# "hue": {
# "weight": 0,
# "min": 0.5,
# "max": 0.5,
# },
# "sharpness": {
# "weight": 0,
# "min": 2.0,
# "max": 2.0,
# },
# }
# self.path = Path("tests/data/save_image_transforms")
# self.original_frame = self.load_png_to_tensor(self.path / "original_frame.png")
# self.transforms = {
# "brightness": v2.ColorJitter(brightness=(2.0, 2.0)),
# "contrast": v2.ColorJitter(contrast=(2.0, 2.0)),
# "saturation": v2.ColorJitter(saturation=(2.0, 2.0)),
# "hue": v2.ColorJitter(hue=(0.5, 0.5)),
# "sharpness": RangeRandomSharpness(2.0, 2.0),
# }
# @staticmethod
# def load_png_to_tensor(path: Path):
# return torch.from_numpy(np.array(Image.open(path).convert("RGB"))).permute(2, 0, 1)
# @pytest.mark.parametrize(
# "transform_key, seed",
# [
# ("brightness", 1336),
# ("contrast", 1336),
# ("saturation", 1336),
# ("hue", 1336),
# ("sharpness", 1336),
# ],
# )
# def test_single_transform(self, transform_key, seed):
# config = self.config
# config[transform_key]["weight"] = 1
# cfg = OmegaConf.create(config)
# actual_t = make_image_transforms(cfg, to_dtype=torch.uint8)
# with seeded_context(1336):
# actual = actual_t(self.original_frame)
# expected_t = self.transforms[transform_key]
# with seeded_context(1336):
# expected = expected_t(self.original_frame)
# torch.testing.assert_close(actual, expected)