add affine transforms and test (#2145)
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
@@ -206,6 +206,11 @@ class ImageTransformsConfig:
|
|||||||
type="SharpnessJitter",
|
type="SharpnessJitter",
|
||||||
kwargs={"sharpness": (0.5, 1.5)},
|
kwargs={"sharpness": (0.5, 1.5)},
|
||||||
),
|
),
|
||||||
|
"affine": ImageTransformConfig(
|
||||||
|
weight=1.0,
|
||||||
|
type="RandomAffine",
|
||||||
|
kwargs={"degrees": (-5.0, 5.0), "translate": (0.05, 0.05)},
|
||||||
|
),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -217,6 +222,8 @@ def make_transform_from_config(cfg: ImageTransformConfig):
|
|||||||
return v2.ColorJitter(**cfg.kwargs)
|
return v2.ColorJitter(**cfg.kwargs)
|
||||||
elif cfg.type == "SharpnessJitter":
|
elif cfg.type == "SharpnessJitter":
|
||||||
return SharpnessJitter(**cfg.kwargs)
|
return SharpnessJitter(**cfg.kwargs)
|
||||||
|
elif cfg.type == "RandomAffine":
|
||||||
|
return v2.RandomAffine(**cfg.kwargs)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Transform '{cfg.type}' is not valid.")
|
raise ValueError(f"Transform '{cfg.type}' is not valid.")
|
||||||
|
|
||||||
|
|||||||
@@ -134,6 +134,25 @@ def test_get_image_transforms_sharpness(img_tensor_factory, min_max):
|
|||||||
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
|
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("degrees, translate", [((-5.0, 5.0), (0.05, 0.05)), ((10.0, 10.0), (0.1, 0.1))])
|
||||||
|
def test_get_image_transforms_affine(img_tensor_factory, degrees, translate):
|
||||||
|
img_tensor = img_tensor_factory()
|
||||||
|
tf_cfg = ImageTransformsConfig(
|
||||||
|
enable=True,
|
||||||
|
tfs={
|
||||||
|
"affine": ImageTransformConfig(
|
||||||
|
type="RandomAffine", kwargs={"degrees": degrees, "translate": translate}
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
tf = ImageTransforms(tf_cfg)
|
||||||
|
output = tf(img_tensor)
|
||||||
|
# Verify output shape is preserved
|
||||||
|
assert output.shape == img_tensor.shape
|
||||||
|
# Verify transform is type RandomAffine
|
||||||
|
assert isinstance(tf.transforms["affine"], v2.RandomAffine)
|
||||||
|
|
||||||
|
|
||||||
def test_get_image_transforms_max_num_transforms(img_tensor_factory):
|
def test_get_image_transforms_max_num_transforms(img_tensor_factory):
|
||||||
img_tensor = img_tensor_factory()
|
img_tensor = img_tensor_factory()
|
||||||
tf_cfg = ImageTransformsConfig(
|
tf_cfg = ImageTransformsConfig(
|
||||||
@@ -262,7 +281,37 @@ def test_backward_compatibility_default_config(img_tensor, default_transforms):
|
|||||||
# NOTE: PyTorch versions have different randomness, it might break this test.
|
# NOTE: PyTorch versions have different randomness, it might break this test.
|
||||||
# See this PR: https://github.com/huggingface/lerobot/pull/1127.
|
# See this PR: https://github.com/huggingface/lerobot/pull/1127.
|
||||||
|
|
||||||
cfg = ImageTransformsConfig(enable=True)
|
# Use config without affine to match original test artifacts
|
||||||
|
cfg = ImageTransformsConfig(
|
||||||
|
enable=True,
|
||||||
|
tfs={
|
||||||
|
"brightness": ImageTransformConfig(
|
||||||
|
weight=1.0,
|
||||||
|
type="ColorJitter",
|
||||||
|
kwargs={"brightness": (0.8, 1.2)},
|
||||||
|
),
|
||||||
|
"contrast": ImageTransformConfig(
|
||||||
|
weight=1.0,
|
||||||
|
type="ColorJitter",
|
||||||
|
kwargs={"contrast": (0.8, 1.2)},
|
||||||
|
),
|
||||||
|
"saturation": ImageTransformConfig(
|
||||||
|
weight=1.0,
|
||||||
|
type="ColorJitter",
|
||||||
|
kwargs={"saturation": (0.5, 1.5)},
|
||||||
|
),
|
||||||
|
"hue": ImageTransformConfig(
|
||||||
|
weight=1.0,
|
||||||
|
type="ColorJitter",
|
||||||
|
kwargs={"hue": (-0.05, 0.05)},
|
||||||
|
),
|
||||||
|
"sharpness": ImageTransformConfig(
|
||||||
|
weight=1.0,
|
||||||
|
type="SharpnessJitter",
|
||||||
|
kwargs={"sharpness": (0.5, 1.5)},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
default_tf = ImageTransforms(cfg)
|
default_tf = ImageTransforms(cfg)
|
||||||
|
|
||||||
with seeded_context(1337):
|
with seeded_context(1337):
|
||||||
@@ -368,7 +417,7 @@ def test_save_each_transform(img_tensor_factory, tmp_path):
|
|||||||
save_each_transform(tf_cfg, img_tensor, tmp_path, n_examples)
|
save_each_transform(tf_cfg, img_tensor, tmp_path, n_examples)
|
||||||
|
|
||||||
# Check if the transformed images exist for each transform type
|
# Check if the transformed images exist for each transform type
|
||||||
transforms = ["brightness", "contrast", "saturation", "hue", "sharpness"]
|
transforms = ["brightness", "contrast", "saturation", "hue", "sharpness", "affine"]
|
||||||
for transform in transforms:
|
for transform in transforms:
|
||||||
transform_dir = tmp_path / transform
|
transform_dir = tmp_path / transform
|
||||||
assert transform_dir.exists(), f"{transform} directory was not created."
|
assert transform_dir.exists(), f"{transform} directory was not created."
|
||||||
|
|||||||
Reference in New Issue
Block a user