diff --git a/src/lerobot/datasets/transforms.py b/src/lerobot/datasets/transforms.py index f7072c72..beacc48d 100644 --- a/src/lerobot/datasets/transforms.py +++ b/src/lerobot/datasets/transforms.py @@ -206,6 +206,11 @@ class ImageTransformsConfig: type="SharpnessJitter", kwargs={"sharpness": (0.5, 1.5)}, ), + "affine": ImageTransformConfig( + weight=1.0, + type="RandomAffine", + kwargs={"degrees": (-5.0, 5.0), "translate": (0.05, 0.05)}, + ), } ) @@ -217,6 +222,8 @@ def make_transform_from_config(cfg: ImageTransformConfig): return v2.ColorJitter(**cfg.kwargs) elif cfg.type == "SharpnessJitter": return SharpnessJitter(**cfg.kwargs) + elif cfg.type == "RandomAffine": + return v2.RandomAffine(**cfg.kwargs) else: raise ValueError(f"Transform '{cfg.type}' is not valid.") diff --git a/tests/datasets/test_image_transforms.py b/tests/datasets/test_image_transforms.py index 98f95707..8a66ceb2 100644 --- a/tests/datasets/test_image_transforms.py +++ b/tests/datasets/test_image_transforms.py @@ -134,6 +134,25 @@ def test_get_image_transforms_sharpness(img_tensor_factory, min_max): torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor)) +@pytest.mark.parametrize("degrees, translate", [((-5.0, 5.0), (0.05, 0.05)), ((10.0, 10.0), (0.1, 0.1))]) +def test_get_image_transforms_affine(img_tensor_factory, degrees, translate): + img_tensor = img_tensor_factory() + tf_cfg = ImageTransformsConfig( + enable=True, + tfs={ + "affine": ImageTransformConfig( + type="RandomAffine", kwargs={"degrees": degrees, "translate": translate} + ) + }, + ) + tf = ImageTransforms(tf_cfg) + output = tf(img_tensor) + # Verify output shape is preserved + assert output.shape == img_tensor.shape + # Verify transform is type RandomAffine + assert isinstance(tf.transforms["affine"], v2.RandomAffine) + + def test_get_image_transforms_max_num_transforms(img_tensor_factory): img_tensor = img_tensor_factory() tf_cfg = ImageTransformsConfig( @@ -262,7 +281,37 @@ def test_backward_compatibility_default_config(img_tensor, default_transforms): # NOTE: PyTorch versions have different randomness, it might break this test. # See this PR: https://github.com/huggingface/lerobot/pull/1127. - cfg = ImageTransformsConfig(enable=True) + # Use config without affine to match original test artifacts + cfg = ImageTransformsConfig( + enable=True, + tfs={ + "brightness": ImageTransformConfig( + weight=1.0, + type="ColorJitter", + kwargs={"brightness": (0.8, 1.2)}, + ), + "contrast": ImageTransformConfig( + weight=1.0, + type="ColorJitter", + kwargs={"contrast": (0.8, 1.2)}, + ), + "saturation": ImageTransformConfig( + weight=1.0, + type="ColorJitter", + kwargs={"saturation": (0.5, 1.5)}, + ), + "hue": ImageTransformConfig( + weight=1.0, + type="ColorJitter", + kwargs={"hue": (-0.05, 0.05)}, + ), + "sharpness": ImageTransformConfig( + weight=1.0, + type="SharpnessJitter", + kwargs={"sharpness": (0.5, 1.5)}, + ), + }, + ) default_tf = ImageTransforms(cfg) with seeded_context(1337): @@ -368,7 +417,7 @@ def test_save_each_transform(img_tensor_factory, tmp_path): save_each_transform(tf_cfg, img_tensor, tmp_path, n_examples) # Check if the transformed images exist for each transform type - transforms = ["brightness", "contrast", "saturation", "hue", "sharpness"] + transforms = ["brightness", "contrast", "saturation", "hue", "sharpness", "affine"] for transform in transforms: transform_dir = tmp_path / transform assert transform_dir.exists(), f"{transform} directory was not created."