Simplify configs (#550)
Co-authored-by: Remi <remi.cadene@huggingface.co> Co-authored-by: HUANG TZU-CHUN <137322177+tc-huang@users.noreply.github.com>
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import collections
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Sequence
|
||||
|
||||
import torch
|
||||
@@ -65,6 +66,8 @@ class RandomSubsetApply(Transform):
|
||||
self.n_subset = n_subset
|
||||
self.random_order = random_order
|
||||
|
||||
self.selected_transforms = None
|
||||
|
||||
def forward(self, *inputs: Any) -> Any:
|
||||
needs_unpacking = len(inputs) > 1
|
||||
|
||||
@@ -72,9 +75,9 @@ class RandomSubsetApply(Transform):
|
||||
if not self.random_order:
|
||||
selected_indices = selected_indices.sort().values
|
||||
|
||||
selected_transforms = [self.transforms[i] for i in selected_indices]
|
||||
self.selected_transforms = [self.transforms[i] for i in selected_indices]
|
||||
|
||||
for transform in selected_transforms:
|
||||
for transform in self.selected_transforms:
|
||||
outputs = transform(*inputs)
|
||||
inputs = outputs if needs_unpacking else (outputs,)
|
||||
|
||||
@@ -138,61 +141,109 @@ class SharpnessJitter(Transform):
|
||||
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
|
||||
|
||||
|
||||
def get_image_transforms(
|
||||
brightness_weight: float = 1.0,
|
||||
brightness_min_max: tuple[float, float] | None = None,
|
||||
contrast_weight: float = 1.0,
|
||||
contrast_min_max: tuple[float, float] | None = None,
|
||||
saturation_weight: float = 1.0,
|
||||
saturation_min_max: tuple[float, float] | None = None,
|
||||
hue_weight: float = 1.0,
|
||||
hue_min_max: tuple[float, float] | None = None,
|
||||
sharpness_weight: float = 1.0,
|
||||
sharpness_min_max: tuple[float, float] | None = None,
|
||||
max_num_transforms: int | None = None,
|
||||
random_order: bool = False,
|
||||
):
|
||||
def check_value(name, weight, min_max):
|
||||
if min_max is not None:
|
||||
if len(min_max) != 2:
|
||||
raise ValueError(
|
||||
f"`{name}_min_max` is expected to be a tuple of 2 dimensions, but {min_max} provided."
|
||||
)
|
||||
if weight < 0.0:
|
||||
raise ValueError(
|
||||
f"`{name}_weight` is expected to be 0 or positive, but is negative ({weight})."
|
||||
)
|
||||
@dataclass
|
||||
class ImageTransformConfig:
|
||||
"""
|
||||
For each transform, the following parameters are available:
|
||||
weight: This represents the multinomial probability (with no replacement)
|
||||
used for sampling the transform. If the sum of the weights is not 1,
|
||||
they will be normalized.
|
||||
type: The name of the class used. This is either a class available under torchvision.transforms.v2 or a
|
||||
custom transform defined here.
|
||||
kwargs: Lower & upper bound respectively used for sampling the transform's parameter
|
||||
(following uniform distribution) when it's applied.
|
||||
"""
|
||||
|
||||
check_value("brightness", brightness_weight, brightness_min_max)
|
||||
check_value("contrast", contrast_weight, contrast_min_max)
|
||||
check_value("saturation", saturation_weight, saturation_min_max)
|
||||
check_value("hue", hue_weight, hue_min_max)
|
||||
check_value("sharpness", sharpness_weight, sharpness_min_max)
|
||||
weight: float = 1.0
|
||||
type: str = "Identity"
|
||||
kwargs: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
weights = []
|
||||
transforms = []
|
||||
if brightness_min_max is not None and brightness_weight > 0.0:
|
||||
weights.append(brightness_weight)
|
||||
transforms.append(v2.ColorJitter(brightness=brightness_min_max))
|
||||
if contrast_min_max is not None and contrast_weight > 0.0:
|
||||
weights.append(contrast_weight)
|
||||
transforms.append(v2.ColorJitter(contrast=contrast_min_max))
|
||||
if saturation_min_max is not None and saturation_weight > 0.0:
|
||||
weights.append(saturation_weight)
|
||||
transforms.append(v2.ColorJitter(saturation=saturation_min_max))
|
||||
if hue_min_max is not None and hue_weight > 0.0:
|
||||
weights.append(hue_weight)
|
||||
transforms.append(v2.ColorJitter(hue=hue_min_max))
|
||||
if sharpness_min_max is not None and sharpness_weight > 0.0:
|
||||
weights.append(sharpness_weight)
|
||||
transforms.append(SharpnessJitter(sharpness=sharpness_min_max))
|
||||
|
||||
n_subset = len(transforms)
|
||||
if max_num_transforms is not None:
|
||||
n_subset = min(n_subset, max_num_transforms)
|
||||
@dataclass
|
||||
class ImageTransformsConfig:
|
||||
"""
|
||||
These transforms are all using standard torchvision.transforms.v2
|
||||
You can find out how these transformations affect images here:
|
||||
https://pytorch.org/vision/0.18/auto_examples/transforms/plot_transforms_illustrations.html
|
||||
We use a custom RandomSubsetApply container to sample them.
|
||||
"""
|
||||
|
||||
if n_subset == 0:
|
||||
return v2.Identity()
|
||||
# Set this flag to `true` to enable transforms during training
|
||||
enable: bool = False
|
||||
# This is the maximum number of transforms (sampled from these below) that will be applied to each frame.
|
||||
# It's an integer in the interval [1, number_of_available_transforms].
|
||||
max_num_transforms: int = 3
|
||||
# By default, transforms are applied in Torchvision's suggested order (shown below).
|
||||
# Set this to True to apply them in a random order.
|
||||
random_order: bool = False
|
||||
tfs: dict[str, ImageTransformConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"brightness": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="ColorJitter",
|
||||
kwargs={"brightness": (0.8, 1.2)},
|
||||
),
|
||||
"contrast": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="ColorJitter",
|
||||
kwargs={"contrast": (0.8, 1.2)},
|
||||
),
|
||||
"saturation": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="ColorJitter",
|
||||
kwargs={"saturation": (0.5, 1.5)},
|
||||
),
|
||||
"hue": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="ColorJitter",
|
||||
kwargs={"hue": (-0.05, 0.05)},
|
||||
),
|
||||
"sharpness": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="SharpnessJitter",
|
||||
kwargs={"sharpness": (0.5, 1.5)},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def make_transform_from_config(cfg: ImageTransformConfig):
|
||||
if cfg.type == "Identity":
|
||||
return v2.Identity(**cfg.kwargs)
|
||||
elif cfg.type == "ColorJitter":
|
||||
return v2.ColorJitter(**cfg.kwargs)
|
||||
elif cfg.type == "SharpnessJitter":
|
||||
return SharpnessJitter(**cfg.kwargs)
|
||||
else:
|
||||
# TODO(rcadene, aliberts): add v2.ToDtype float16?
|
||||
return RandomSubsetApply(transforms, p=weights, n_subset=n_subset, random_order=random_order)
|
||||
raise ValueError(f"Transform '{cfg.type}' is not valid.")
|
||||
|
||||
|
||||
class ImageTransforms(Transform):
|
||||
"""A class to compose image transforms based on configuration."""
|
||||
|
||||
def __init__(self, cfg: ImageTransformsConfig) -> None:
|
||||
super().__init__()
|
||||
self._cfg = cfg
|
||||
|
||||
self.weights = []
|
||||
self.transforms = {}
|
||||
for tf_name, tf_cfg in cfg.tfs.items():
|
||||
if tf_cfg.weight <= 0.0:
|
||||
continue
|
||||
|
||||
self.transforms[tf_name] = make_transform_from_config(tf_cfg)
|
||||
self.weights.append(tf_cfg.weight)
|
||||
|
||||
n_subset = min(len(self.transforms), cfg.max_num_transforms)
|
||||
if n_subset == 0 or not cfg.enable:
|
||||
self.tf = v2.Identity()
|
||||
else:
|
||||
self.tf = RandomSubsetApply(
|
||||
transforms=list(self.transforms.values()),
|
||||
p=self.weights,
|
||||
n_subset=n_subset,
|
||||
random_order=cfg.random_order,
|
||||
)
|
||||
|
||||
def forward(self, *inputs: Any) -> Any:
|
||||
return self.tf(*inputs)
|
||||
|
||||
Reference in New Issue
Block a user