added visualization for min and max transforms (#271)
Co-authored-by: Simon Alibert <alibert.sim@gmail.com>
This commit is contained in:
committed by
GitHub
parent
a92d79fff2
commit
e28fa2344c
@@ -26,6 +26,7 @@ from torchvision.transforms.v2 import functional as F # noqa: N812
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.transforms import RandomSubsetApply, SharpnessJitter, get_image_transforms
|
||||
from lerobot.common.utils.utils import init_hydra_config, seeded_context
|
||||
from lerobot.scripts.visualize_image_transforms import visualize_transforms
|
||||
from tests.utils import DEFAULT_CONFIG_PATH, require_x86_64_kernel
|
||||
|
||||
ARTIFACT_DIR = Path("tests/data/save_image_transforms_to_safetensors")
|
||||
@@ -258,3 +259,44 @@ def test_sharpness_jitter_invalid_range_min_negative():
|
||||
def test_sharpness_jitter_invalid_range_max_smaller():
|
||||
with pytest.raises(ValueError):
|
||||
SharpnessJitter((2.0, 0.1))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"repo_id, n_examples",
|
||||
[
|
||||
("lerobot/aloha_sim_transfer_cube_human", 3),
|
||||
],
|
||||
)
|
||||
def test_visualize_image_transforms(repo_id, n_examples):
|
||||
cfg = init_hydra_config(DEFAULT_CONFIG_PATH, overrides=[f"dataset_repo_id={repo_id}"])
|
||||
output_dir = Path(__file__).parent / "outputs" / "image_transforms"
|
||||
visualize_transforms(cfg, output_dir=output_dir, n_examples=n_examples)
|
||||
output_dir = output_dir / repo_id.split("/")[-1]
|
||||
|
||||
# Check if the original frame image exists
|
||||
assert (output_dir / "original_frame.png").exists(), "Original frame image was not saved."
|
||||
|
||||
# Check if the transformed images exist for each transform type
|
||||
transforms = ["brightness", "contrast", "saturation", "hue", "sharpness"]
|
||||
for transform in transforms:
|
||||
transform_dir = output_dir / transform
|
||||
assert transform_dir.exists(), f"{transform} directory was not created."
|
||||
assert any(transform_dir.iterdir()), f"No transformed images found in {transform} directory."
|
||||
|
||||
# Check for specific files within each transform directory
|
||||
expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + ["min.png", "max.png", "mean.png"]
|
||||
for file_name in expected_files:
|
||||
assert (
|
||||
transform_dir / file_name
|
||||
).exists(), f"{file_name} was not found in {transform} directory."
|
||||
|
||||
# Check if the combined transforms directory exists and contains the right files
|
||||
combined_transforms_dir = output_dir / "all"
|
||||
assert combined_transforms_dir.exists(), "Combined transforms directory was not created."
|
||||
assert any(
|
||||
combined_transforms_dir.iterdir()
|
||||
), "No transformed images found in combined transforms directory."
|
||||
for i in range(1, n_examples + 1):
|
||||
assert (
|
||||
combined_transforms_dir / f"{i}.png"
|
||||
).exists(), f"Combined transform image {i}.png was not found."
|
||||
|
||||
Reference in New Issue
Block a user