Simplify configs (#550)
Co-authored-by: Remi <remi.cadene@huggingface.co> Co-authored-by: HUANG TZU-CHUN <137322177+tc-huang@users.noreply.github.com>
This commit is contained in:
@@ -13,7 +13,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -21,13 +20,21 @@ from safetensors.torch import load_file
|
||||
from torchvision.transforms import v2
|
||||
from torchvision.transforms.v2 import functional as F # noqa: N812
|
||||
|
||||
from lerobot.common.datasets.transforms import RandomSubsetApply, SharpnessJitter, get_image_transforms
|
||||
from lerobot.common.utils.utils import init_hydra_config, seeded_context
|
||||
from lerobot.scripts.visualize_image_transforms import visualize_transforms
|
||||
from tests.utils import DEFAULT_CONFIG_PATH, require_x86_64_kernel
|
||||
|
||||
ARTIFACT_DIR = Path("tests/data/save_image_transforms_to_safetensors")
|
||||
DATASET_REPO_ID = "lerobot/aloha_mobile_shrimp"
|
||||
from lerobot.common.datasets.transforms import (
|
||||
ImageTransformConfig,
|
||||
ImageTransforms,
|
||||
ImageTransformsConfig,
|
||||
RandomSubsetApply,
|
||||
SharpnessJitter,
|
||||
make_transform_from_config,
|
||||
)
|
||||
from lerobot.common.utils.utils import seeded_context
|
||||
from lerobot.scripts.visualize_image_transforms import (
|
||||
save_all_transforms,
|
||||
save_each_transform,
|
||||
)
|
||||
from tests.scripts.save_image_transforms_to_safetensors import ARTIFACT_DIR
|
||||
from tests.utils import require_x86_64_kernel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -44,21 +51,38 @@ def single_transforms():
|
||||
return load_file(ARTIFACT_DIR / "single_transforms.safetensors")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def img_tensor(single_transforms):
|
||||
return single_transforms["original_frame"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_transforms():
|
||||
return load_file(ARTIFACT_DIR / "default_transforms.safetensors")
|
||||
|
||||
|
||||
def test_get_image_transforms_no_transform(img_tensor_factory):
|
||||
def test_get_image_transforms_no_transform_enable_false(img_tensor_factory):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_actual = get_image_transforms(brightness_min_max=(0.5, 0.5), max_num_transforms=0)
|
||||
tf_cfg = ImageTransformsConfig() # default is enable=False
|
||||
tf_actual = ImageTransforms(tf_cfg)
|
||||
torch.testing.assert_close(tf_actual(img_tensor), img_tensor)
|
||||
|
||||
|
||||
def test_get_image_transforms_no_transform_max_num_transforms_0(img_tensor_factory):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_cfg = ImageTransformsConfig(enable=True, max_num_transforms=0)
|
||||
tf_actual = ImageTransforms(tf_cfg)
|
||||
torch.testing.assert_close(tf_actual(img_tensor), img_tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
|
||||
def test_get_image_transforms_brightness(img_tensor_factory, min_max):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_actual = get_image_transforms(brightness_weight=1.0, brightness_min_max=min_max)
|
||||
tf_cfg = ImageTransformsConfig(
|
||||
enable=True,
|
||||
tfs={"brightness": ImageTransformConfig(type="ColorJitter", kwargs={"brightness": min_max})},
|
||||
)
|
||||
tf_actual = ImageTransforms(tf_cfg)
|
||||
tf_expected = v2.ColorJitter(brightness=min_max)
|
||||
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
|
||||
|
||||
@@ -66,7 +90,10 @@ def test_get_image_transforms_brightness(img_tensor_factory, min_max):
|
||||
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
|
||||
def test_get_image_transforms_contrast(img_tensor_factory, min_max):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_actual = get_image_transforms(contrast_weight=1.0, contrast_min_max=min_max)
|
||||
tf_cfg = ImageTransformsConfig(
|
||||
enable=True, tfs={"contrast": ImageTransformConfig(type="ColorJitter", kwargs={"contrast": min_max})}
|
||||
)
|
||||
tf_actual = ImageTransforms(tf_cfg)
|
||||
tf_expected = v2.ColorJitter(contrast=min_max)
|
||||
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
|
||||
|
||||
@@ -74,7 +101,11 @@ def test_get_image_transforms_contrast(img_tensor_factory, min_max):
|
||||
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
|
||||
def test_get_image_transforms_saturation(img_tensor_factory, min_max):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_actual = get_image_transforms(saturation_weight=1.0, saturation_min_max=min_max)
|
||||
tf_cfg = ImageTransformsConfig(
|
||||
enable=True,
|
||||
tfs={"saturation": ImageTransformConfig(type="ColorJitter", kwargs={"saturation": min_max})},
|
||||
)
|
||||
tf_actual = ImageTransforms(tf_cfg)
|
||||
tf_expected = v2.ColorJitter(saturation=min_max)
|
||||
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
|
||||
|
||||
@@ -82,7 +113,10 @@ def test_get_image_transforms_saturation(img_tensor_factory, min_max):
|
||||
@pytest.mark.parametrize("min_max", [(-0.25, -0.25), (0.25, 0.25)])
|
||||
def test_get_image_transforms_hue(img_tensor_factory, min_max):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_actual = get_image_transforms(hue_weight=1.0, hue_min_max=min_max)
|
||||
tf_cfg = ImageTransformsConfig(
|
||||
enable=True, tfs={"hue": ImageTransformConfig(type="ColorJitter", kwargs={"hue": min_max})}
|
||||
)
|
||||
tf_actual = ImageTransforms(tf_cfg)
|
||||
tf_expected = v2.ColorJitter(hue=min_max)
|
||||
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
|
||||
|
||||
@@ -90,21 +124,49 @@ def test_get_image_transforms_hue(img_tensor_factory, min_max):
|
||||
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
|
||||
def test_get_image_transforms_sharpness(img_tensor_factory, min_max):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_actual = get_image_transforms(sharpness_weight=1.0, sharpness_min_max=min_max)
|
||||
tf_cfg = ImageTransformsConfig(
|
||||
enable=True,
|
||||
tfs={"sharpness": ImageTransformConfig(type="SharpnessJitter", kwargs={"sharpness": min_max})},
|
||||
)
|
||||
tf_actual = ImageTransforms(tf_cfg)
|
||||
tf_expected = SharpnessJitter(sharpness=min_max)
|
||||
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
|
||||
|
||||
|
||||
def test_get_image_transforms_max_num_transforms(img_tensor_factory):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_actual = get_image_transforms(
|
||||
brightness_min_max=(0.5, 0.5),
|
||||
contrast_min_max=(0.5, 0.5),
|
||||
saturation_min_max=(0.5, 0.5),
|
||||
hue_min_max=(0.5, 0.5),
|
||||
sharpness_min_max=(0.5, 0.5),
|
||||
random_order=False,
|
||||
tf_cfg = ImageTransformsConfig(
|
||||
enable=True,
|
||||
max_num_transforms=5,
|
||||
tfs={
|
||||
"brightness": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="ColorJitter",
|
||||
kwargs={"brightness": (0.5, 0.5)},
|
||||
),
|
||||
"contrast": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="ColorJitter",
|
||||
kwargs={"contrast": (0.5, 0.5)},
|
||||
),
|
||||
"saturation": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="ColorJitter",
|
||||
kwargs={"saturation": (0.5, 0.5)},
|
||||
),
|
||||
"hue": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="ColorJitter",
|
||||
kwargs={"hue": (0.5, 0.5)},
|
||||
),
|
||||
"sharpness": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="SharpnessJitter",
|
||||
kwargs={"sharpness": (0.5, 0.5)},
|
||||
),
|
||||
},
|
||||
)
|
||||
tf_actual = ImageTransforms(tf_cfg)
|
||||
tf_expected = v2.Compose(
|
||||
[
|
||||
v2.ColorJitter(brightness=(0.5, 0.5)),
|
||||
@@ -121,68 +183,79 @@ def test_get_image_transforms_max_num_transforms(img_tensor_factory):
|
||||
def test_get_image_transforms_random_order(img_tensor_factory):
|
||||
out_imgs = []
|
||||
img_tensor = img_tensor_factory()
|
||||
tf = get_image_transforms(
|
||||
brightness_min_max=(0.5, 0.5),
|
||||
contrast_min_max=(0.5, 0.5),
|
||||
saturation_min_max=(0.5, 0.5),
|
||||
hue_min_max=(0.5, 0.5),
|
||||
sharpness_min_max=(0.5, 0.5),
|
||||
tf_cfg = ImageTransformsConfig(
|
||||
enable=True,
|
||||
random_order=True,
|
||||
tfs={
|
||||
"brightness": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="ColorJitter",
|
||||
kwargs={"brightness": (0.5, 0.5)},
|
||||
),
|
||||
"contrast": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="ColorJitter",
|
||||
kwargs={"contrast": (0.5, 0.5)},
|
||||
),
|
||||
"saturation": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="ColorJitter",
|
||||
kwargs={"saturation": (0.5, 0.5)},
|
||||
),
|
||||
"hue": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="ColorJitter",
|
||||
kwargs={"hue": (0.5, 0.5)},
|
||||
),
|
||||
"sharpness": ImageTransformConfig(
|
||||
weight=1.0,
|
||||
type="SharpnessJitter",
|
||||
kwargs={"sharpness": (0.5, 0.5)},
|
||||
),
|
||||
},
|
||||
)
|
||||
with seeded_context(1337):
|
||||
tf = ImageTransforms(tf_cfg)
|
||||
|
||||
with seeded_context(1338):
|
||||
for _ in range(10):
|
||||
out_imgs.append(tf(img_tensor))
|
||||
|
||||
tmp_img_tensor = img_tensor
|
||||
for sub_tf in tf.tf.selected_transforms:
|
||||
tmp_img_tensor = sub_tf(tmp_img_tensor)
|
||||
torch.testing.assert_close(tmp_img_tensor, out_imgs[-1])
|
||||
|
||||
for i in range(1, len(out_imgs)):
|
||||
with pytest.raises(AssertionError):
|
||||
torch.testing.assert_close(out_imgs[0], out_imgs[i])
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
@pytest.mark.parametrize(
|
||||
"transform, min_max_values",
|
||||
"tf_type, tf_name, min_max_values",
|
||||
[
|
||||
("brightness", [(0.5, 0.5), (2.0, 2.0)]),
|
||||
("contrast", [(0.5, 0.5), (2.0, 2.0)]),
|
||||
("saturation", [(0.5, 0.5), (2.0, 2.0)]),
|
||||
("hue", [(-0.25, -0.25), (0.25, 0.25)]),
|
||||
("sharpness", [(0.5, 0.5), (2.0, 2.0)]),
|
||||
("ColorJitter", "brightness", [(0.5, 0.5), (2.0, 2.0)]),
|
||||
("ColorJitter", "contrast", [(0.5, 0.5), (2.0, 2.0)]),
|
||||
("ColorJitter", "saturation", [(0.5, 0.5), (2.0, 2.0)]),
|
||||
("ColorJitter", "hue", [(-0.25, -0.25), (0.25, 0.25)]),
|
||||
("SharpnessJitter", "sharpness", [(0.5, 0.5), (2.0, 2.0)]),
|
||||
],
|
||||
)
|
||||
def test_backward_compatibility_torchvision(img_tensor_factory, transform, min_max_values, single_transforms):
|
||||
img_tensor = img_tensor_factory()
|
||||
def test_backward_compatibility_single_transforms(
|
||||
img_tensor, tf_type, tf_name, min_max_values, single_transforms
|
||||
):
|
||||
for min_max in min_max_values:
|
||||
kwargs = {
|
||||
f"{transform}_weight": 1.0,
|
||||
f"{transform}_min_max": min_max,
|
||||
}
|
||||
tf = get_image_transforms(**kwargs)
|
||||
tf_cfg = ImageTransformConfig(type=tf_type, kwargs={tf_name: min_max})
|
||||
tf = make_transform_from_config(tf_cfg)
|
||||
actual = tf(img_tensor)
|
||||
key = f"{transform}_{min_max[0]}_{min_max[1]}"
|
||||
key = f"{tf_name}_{min_max[0]}_{min_max[1]}"
|
||||
expected = single_transforms[key]
|
||||
torch.testing.assert_close(actual, expected)
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
@require_x86_64_kernel
|
||||
def test_backward_compatibility_default_config(img_tensor_factory, default_transforms):
|
||||
img_tensor = img_tensor_factory()
|
||||
cfg = init_hydra_config(DEFAULT_CONFIG_PATH)
|
||||
cfg_tf = cfg.training.image_transforms
|
||||
default_tf = get_image_transforms(
|
||||
brightness_weight=cfg_tf.brightness.weight,
|
||||
brightness_min_max=cfg_tf.brightness.min_max,
|
||||
contrast_weight=cfg_tf.contrast.weight,
|
||||
contrast_min_max=cfg_tf.contrast.min_max,
|
||||
saturation_weight=cfg_tf.saturation.weight,
|
||||
saturation_min_max=cfg_tf.saturation.min_max,
|
||||
hue_weight=cfg_tf.hue.weight,
|
||||
hue_min_max=cfg_tf.hue.min_max,
|
||||
sharpness_weight=cfg_tf.sharpness.weight,
|
||||
sharpness_min_max=cfg_tf.sharpness.min_max,
|
||||
max_num_transforms=cfg_tf.max_num_transforms,
|
||||
random_order=cfg_tf.random_order,
|
||||
)
|
||||
def test_backward_compatibility_default_config(img_tensor, default_transforms):
|
||||
cfg = ImageTransformsConfig(enable=True)
|
||||
default_tf = ImageTransforms(cfg)
|
||||
|
||||
with seeded_context(1337):
|
||||
actual = default_tf(img_tensor)
|
||||
@@ -260,26 +333,36 @@ def test_sharpness_jitter_invalid_range_max_smaller():
|
||||
SharpnessJitter((2.0, 0.1))
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
@pytest.mark.parametrize(
|
||||
"repo_id, n_examples",
|
||||
[
|
||||
("lerobot/aloha_sim_transfer_cube_human", 3),
|
||||
],
|
||||
)
|
||||
def test_visualize_image_transforms(repo_id, n_examples):
|
||||
cfg = init_hydra_config(DEFAULT_CONFIG_PATH, overrides=[f"dataset_repo_id={repo_id}"])
|
||||
output_dir = Path(__file__).parent / "outputs" / "image_transforms"
|
||||
visualize_transforms(cfg, output_dir=output_dir, n_examples=n_examples)
|
||||
output_dir = output_dir / repo_id.split("/")[-1]
|
||||
def test_save_all_transforms(img_tensor_factory, tmp_path):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_cfg = ImageTransformsConfig(enable=True)
|
||||
n_examples = 3
|
||||
|
||||
# Check if the original frame image exists
|
||||
assert (output_dir / "original_frame.png").exists(), "Original frame image was not saved."
|
||||
save_all_transforms(tf_cfg, img_tensor, tmp_path, n_examples)
|
||||
|
||||
# Check if the combined transforms directory exists and contains the right files
|
||||
combined_transforms_dir = tmp_path / "all"
|
||||
assert combined_transforms_dir.exists(), "Combined transforms directory was not created."
|
||||
assert any(
|
||||
combined_transforms_dir.iterdir()
|
||||
), "No transformed images found in combined transforms directory."
|
||||
for i in range(1, n_examples + 1):
|
||||
assert (
|
||||
combined_transforms_dir / f"{i}.png"
|
||||
).exists(), f"Combined transform image {i}.png was not found."
|
||||
|
||||
|
||||
def test_save_each_transform(img_tensor_factory, tmp_path):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_cfg = ImageTransformsConfig(enable=True)
|
||||
n_examples = 3
|
||||
|
||||
save_each_transform(tf_cfg, img_tensor, tmp_path, n_examples)
|
||||
|
||||
# Check if the transformed images exist for each transform type
|
||||
transforms = ["brightness", "contrast", "saturation", "hue", "sharpness"]
|
||||
for transform in transforms:
|
||||
transform_dir = output_dir / transform
|
||||
transform_dir = tmp_path / transform
|
||||
assert transform_dir.exists(), f"{transform} directory was not created."
|
||||
assert any(transform_dir.iterdir()), f"No transformed images found in {transform} directory."
|
||||
|
||||
@@ -289,14 +372,3 @@ def test_visualize_image_transforms(repo_id, n_examples):
|
||||
assert (
|
||||
transform_dir / file_name
|
||||
).exists(), f"{file_name} was not found in {transform} directory."
|
||||
|
||||
# Check if the combined transforms directory exists and contains the right files
|
||||
combined_transforms_dir = output_dir / "all"
|
||||
assert combined_transforms_dir.exists(), "Combined transforms directory was not created."
|
||||
assert any(
|
||||
combined_transforms_dir.iterdir()
|
||||
), "No transformed images found in combined transforms directory."
|
||||
for i in range(1, n_examples + 1):
|
||||
assert (
|
||||
combined_transforms_dir / f"{i}.png"
|
||||
).exists(), f"Combined transform image {i}.png was not found."
|
||||
|
||||
Reference in New Issue
Block a user