Simplify configs (#550)

Co-authored-by: Remi <remi.cadene@huggingface.co>
Co-authored-by: HUANG TZU-CHUN <137322177+tc-huang@users.noreply.github.com>
This commit is contained in:
Simon Alibert
2025-01-31 13:57:37 +01:00
committed by GitHub
parent 1ee1acf8ad
commit 3c0a209f9f
119 changed files with 5761 additions and 5466 deletions

View File

@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
from dataclasses import dataclass, field
from typing import Any, Callable, Sequence
import torch
@@ -65,6 +66,8 @@ class RandomSubsetApply(Transform):
self.n_subset = n_subset
self.random_order = random_order
self.selected_transforms = None
def forward(self, *inputs: Any) -> Any:
needs_unpacking = len(inputs) > 1
@@ -72,9 +75,9 @@ class RandomSubsetApply(Transform):
if not self.random_order:
selected_indices = selected_indices.sort().values
selected_transforms = [self.transforms[i] for i in selected_indices]
self.selected_transforms = [self.transforms[i] for i in selected_indices]
for transform in selected_transforms:
for transform in self.selected_transforms:
outputs = transform(*inputs)
inputs = outputs if needs_unpacking else (outputs,)
@@ -138,61 +141,109 @@ class SharpnessJitter(Transform):
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
def get_image_transforms(
brightness_weight: float = 1.0,
brightness_min_max: tuple[float, float] | None = None,
contrast_weight: float = 1.0,
contrast_min_max: tuple[float, float] | None = None,
saturation_weight: float = 1.0,
saturation_min_max: tuple[float, float] | None = None,
hue_weight: float = 1.0,
hue_min_max: tuple[float, float] | None = None,
sharpness_weight: float = 1.0,
sharpness_min_max: tuple[float, float] | None = None,
max_num_transforms: int | None = None,
random_order: bool = False,
):
def check_value(name, weight, min_max):
if min_max is not None:
if len(min_max) != 2:
raise ValueError(
f"`{name}_min_max` is expected to be a tuple of 2 dimensions, but {min_max} provided."
)
if weight < 0.0:
raise ValueError(
f"`{name}_weight` is expected to be 0 or positive, but is negative ({weight})."
)
@dataclass
class ImageTransformConfig:
"""
For each transform, the following parameters are available:
weight: This represents the multinomial probability (with no replacement)
used for sampling the transform. If the sum of the weights is not 1,
they will be normalized.
type: The name of the class used. This is either a class available under torchvision.transforms.v2 or a
custom transform defined here.
kwargs: Lower & upper bound respectively used for sampling the transform's parameter
(following uniform distribution) when it's applied.
"""
check_value("brightness", brightness_weight, brightness_min_max)
check_value("contrast", contrast_weight, contrast_min_max)
check_value("saturation", saturation_weight, saturation_min_max)
check_value("hue", hue_weight, hue_min_max)
check_value("sharpness", sharpness_weight, sharpness_min_max)
weight: float = 1.0
type: str = "Identity"
kwargs: dict[str, Any] = field(default_factory=dict)
weights = []
transforms = []
if brightness_min_max is not None and brightness_weight > 0.0:
weights.append(brightness_weight)
transforms.append(v2.ColorJitter(brightness=brightness_min_max))
if contrast_min_max is not None and contrast_weight > 0.0:
weights.append(contrast_weight)
transforms.append(v2.ColorJitter(contrast=contrast_min_max))
if saturation_min_max is not None and saturation_weight > 0.0:
weights.append(saturation_weight)
transforms.append(v2.ColorJitter(saturation=saturation_min_max))
if hue_min_max is not None and hue_weight > 0.0:
weights.append(hue_weight)
transforms.append(v2.ColorJitter(hue=hue_min_max))
if sharpness_min_max is not None and sharpness_weight > 0.0:
weights.append(sharpness_weight)
transforms.append(SharpnessJitter(sharpness=sharpness_min_max))
n_subset = len(transforms)
if max_num_transforms is not None:
n_subset = min(n_subset, max_num_transforms)
@dataclass
class ImageTransformsConfig:
"""
These transforms are all using standard torchvision.transforms.v2
You can find out how these transformations affect images here:
https://pytorch.org/vision/0.18/auto_examples/transforms/plot_transforms_illustrations.html
We use a custom RandomSubsetApply container to sample them.
"""
if n_subset == 0:
return v2.Identity()
# Set this flag to `true` to enable transforms during training
enable: bool = False
# This is the maximum number of transforms (sampled from these below) that will be applied to each frame.
# It's an integer in the interval [1, number_of_available_transforms].
max_num_transforms: int = 3
# By default, transforms are applied in Torchvision's suggested order (shown below).
# Set this to True to apply them in a random order.
random_order: bool = False
tfs: dict[str, ImageTransformConfig] = field(
default_factory=lambda: {
"brightness": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"brightness": (0.8, 1.2)},
),
"contrast": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"contrast": (0.8, 1.2)},
),
"saturation": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"saturation": (0.5, 1.5)},
),
"hue": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"hue": (-0.05, 0.05)},
),
"sharpness": ImageTransformConfig(
weight=1.0,
type="SharpnessJitter",
kwargs={"sharpness": (0.5, 1.5)},
),
}
)
def make_transform_from_config(cfg: ImageTransformConfig):
if cfg.type == "Identity":
return v2.Identity(**cfg.kwargs)
elif cfg.type == "ColorJitter":
return v2.ColorJitter(**cfg.kwargs)
elif cfg.type == "SharpnessJitter":
return SharpnessJitter(**cfg.kwargs)
else:
# TODO(rcadene, aliberts): add v2.ToDtype float16?
return RandomSubsetApply(transforms, p=weights, n_subset=n_subset, random_order=random_order)
raise ValueError(f"Transform '{cfg.type}' is not valid.")
class ImageTransforms(Transform):
"""A class to compose image transforms based on configuration."""
def __init__(self, cfg: ImageTransformsConfig) -> None:
super().__init__()
self._cfg = cfg
self.weights = []
self.transforms = {}
for tf_name, tf_cfg in cfg.tfs.items():
if tf_cfg.weight <= 0.0:
continue
self.transforms[tf_name] = make_transform_from_config(tf_cfg)
self.weights.append(tf_cfg.weight)
n_subset = min(len(self.transforms), cfg.max_num_transforms)
if n_subset == 0 or not cfg.enable:
self.tf = v2.Identity()
else:
self.tf = RandomSubsetApply(
transforms=list(self.transforms.values()),
p=self.weights,
n_subset=n_subset,
random_order=cfg.random_order,
)
def forward(self, *inputs: Any) -> Any:
return self.tf(*inputs)