add affine transforms and test (#2145)

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
Bryson Jones
2025-10-19 12:39:30 -07:00
committed by GitHub
parent a95b15ccc0
commit 88100943ef
2 changed files with 58 additions and 2 deletions

View File

@@ -206,6 +206,11 @@ class ImageTransformsConfig:
type="SharpnessJitter",
kwargs={"sharpness": (0.5, 1.5)},
),
"affine": ImageTransformConfig(
weight=1.0,
type="RandomAffine",
kwargs={"degrees": (-5.0, 5.0), "translate": (0.05, 0.05)},
),
}
)
@@ -217,6 +222,8 @@ def make_transform_from_config(cfg: ImageTransformConfig):
return v2.ColorJitter(**cfg.kwargs)
elif cfg.type == "SharpnessJitter":
return SharpnessJitter(**cfg.kwargs)
elif cfg.type == "RandomAffine":
return v2.RandomAffine(**cfg.kwargs)
else:
raise ValueError(f"Transform '{cfg.type}' is not valid.")

View File

@@ -134,6 +134,25 @@ def test_get_image_transforms_sharpness(img_tensor_factory, min_max):
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
@pytest.mark.parametrize("degrees, translate", [((-5.0, 5.0), (0.05, 0.05)), ((10.0, 10.0), (0.1, 0.1))])
def test_get_image_transforms_affine(img_tensor_factory, degrees, translate):
img_tensor = img_tensor_factory()
tf_cfg = ImageTransformsConfig(
enable=True,
tfs={
"affine": ImageTransformConfig(
type="RandomAffine", kwargs={"degrees": degrees, "translate": translate}
)
},
)
tf = ImageTransforms(tf_cfg)
output = tf(img_tensor)
# Verify output shape is preserved
assert output.shape == img_tensor.shape
# Verify transform is type RandomAffine
assert isinstance(tf.transforms["affine"], v2.RandomAffine)
def test_get_image_transforms_max_num_transforms(img_tensor_factory):
img_tensor = img_tensor_factory()
tf_cfg = ImageTransformsConfig(
@@ -262,7 +281,37 @@ def test_backward_compatibility_default_config(img_tensor, default_transforms):
# NOTE: PyTorch versions have different randomness, it might break this test.
# See this PR: https://github.com/huggingface/lerobot/pull/1127.
cfg = ImageTransformsConfig(enable=True)
# Use config without affine to match original test artifacts
cfg = ImageTransformsConfig(
enable=True,
tfs={
"brightness": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"brightness": (0.8, 1.2)},
),
"contrast": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"contrast": (0.8, 1.2)},
),
"saturation": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"saturation": (0.5, 1.5)},
),
"hue": ImageTransformConfig(
weight=1.0,
type="ColorJitter",
kwargs={"hue": (-0.05, 0.05)},
),
"sharpness": ImageTransformConfig(
weight=1.0,
type="SharpnessJitter",
kwargs={"sharpness": (0.5, 1.5)},
),
},
)
default_tf = ImageTransforms(cfg)
with seeded_context(1337):
@@ -368,7 +417,7 @@ def test_save_each_transform(img_tensor_factory, tmp_path):
save_each_transform(tf_cfg, img_tensor, tmp_path, n_examples)
# Check if the transformed images exist for each transform type
transforms = ["brightness", "contrast", "saturation", "hue", "sharpness"]
transforms = ["brightness", "contrast", "saturation", "hue", "sharpness", "affine"]
for transform in transforms:
transform_dir = tmp_path / transform
assert transform_dir.exists(), f"{transform} directory was not created."