[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-03-04 13:38:47 +00:00
parent d8a1758122
commit 584cad808e
108 changed files with 3894 additions and 1189 deletions

View File

@@ -57,7 +57,9 @@ class RandomSubsetApply(Transform):
elif not isinstance(n_subset, int):
raise TypeError("n_subset should be an int or None")
elif not (1 <= n_subset <= len(transforms)):
raise ValueError(f"n_subset should be in the interval [1, {len(transforms)}]")
raise ValueError(
f"n_subset should be in the interval [1, {len(transforms)}]"
)
self.transforms = transforms
total = sum(p)
@@ -116,16 +118,22 @@ class SharpnessJitter(Transform):
def _check_input(self, sharpness):
if isinstance(sharpness, (int, float)):
if sharpness < 0:
raise ValueError("If sharpness is a single number, it must be non negative.")
raise ValueError(
"If sharpness is a single number, it must be non negative."
)
sharpness = [1.0 - sharpness, 1.0 + sharpness]
sharpness[0] = max(sharpness[0], 0.0)
elif isinstance(sharpness, collections.abc.Sequence) and len(sharpness) == 2:
sharpness = [float(v) for v in sharpness]
else:
raise TypeError(f"{sharpness=} should be a single number or a sequence with length 2.")
raise TypeError(
f"{sharpness=} should be a single number or a sequence with length 2."
)
if not 0.0 <= sharpness[0] <= sharpness[1]:
raise ValueError(f"sharpnesss values should be between (0., inf), but got {sharpness}.")
raise ValueError(
f"sharpnesss values should be between (0., inf), but got {sharpness}."
)
return float(sharpness[0]), float(sharpness[1])
@@ -134,7 +142,9 @@ class SharpnessJitter(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
sharpness_factor = self._generate_value(self.sharpness[0], self.sharpness[1])
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
return self._call_kernel(
F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor
)
def get_image_transforms(
@@ -185,7 +195,11 @@ def get_image_transforms(
raise ValueError("The interpolation passed is not supported")
# Weight for resizing is always 1
weights.append(1.0)
transforms.append(v2.Resize(size=(image_size[0], image_size[1]), interpolation=interpolation_mode))
transforms.append(
v2.Resize(
size=(image_size[0], image_size[1]), interpolation=interpolation_mode
)
)
if brightness_min_max is not None and brightness_weight > 0.0:
weights.append(brightness_weight)
transforms.append(v2.ColorJitter(brightness=brightness_min_max))
@@ -219,4 +233,6 @@ def get_image_transforms(
return v2.Identity()
else:
# TODO(rcadene, aliberts): add v2.ToDtype float16?
return RandomSubsetApply(transforms, p=weights, n_subset=n_subset, random_order=random_order)
return RandomSubsetApply(
transforms, p=weights, n_subset=n_subset, random_order=random_order
)