[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
@@ -21,7 +21,11 @@ from safetensors.torch import load_file
|
||||
from torchvision.transforms import v2
|
||||
from torchvision.transforms.v2 import functional as F # noqa: N812
|
||||
|
||||
from lerobot.common.datasets.transforms import RandomSubsetApply, SharpnessJitter, get_image_transforms
|
||||
from lerobot.common.datasets.transforms import (
|
||||
RandomSubsetApply,
|
||||
SharpnessJitter,
|
||||
get_image_transforms,
|
||||
)
|
||||
from lerobot.common.utils.utils import init_hydra_config, seeded_context
|
||||
from lerobot.scripts.visualize_image_transforms import visualize_transforms
|
||||
from tests.utils import DEFAULT_CONFIG_PATH, require_x86_64_kernel
|
||||
@@ -51,7 +55,9 @@ def default_transforms():
|
||||
|
||||
def test_get_image_transforms_no_transform(img_tensor_factory):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_actual = get_image_transforms(brightness_min_max=(0.5, 0.5), max_num_transforms=0)
|
||||
tf_actual = get_image_transforms(
|
||||
brightness_min_max=(0.5, 0.5), max_num_transforms=0
|
||||
)
|
||||
torch.testing.assert_close(tf_actual(img_tensor), img_tensor)
|
||||
|
||||
|
||||
@@ -149,7 +155,9 @@ def test_get_image_transforms_random_order(img_tensor_factory):
|
||||
("sharpness", [(0.5, 0.5), (2.0, 2.0)]),
|
||||
],
|
||||
)
|
||||
def test_backward_compatibility_torchvision(img_tensor_factory, transform, min_max_values, single_transforms):
|
||||
def test_backward_compatibility_torchvision(
|
||||
img_tensor_factory, transform, min_max_values, single_transforms
|
||||
):
|
||||
img_tensor = img_tensor_factory()
|
||||
for min_max in min_max_values:
|
||||
kwargs = {
|
||||
@@ -268,23 +276,33 @@ def test_sharpness_jitter_invalid_range_max_smaller():
|
||||
],
|
||||
)
|
||||
def test_visualize_image_transforms(repo_id, n_examples):
|
||||
cfg = init_hydra_config(DEFAULT_CONFIG_PATH, overrides=[f"dataset_repo_id={repo_id}"])
|
||||
cfg = init_hydra_config(
|
||||
DEFAULT_CONFIG_PATH, overrides=[f"dataset_repo_id={repo_id}"]
|
||||
)
|
||||
output_dir = Path(__file__).parent / "outputs" / "image_transforms"
|
||||
visualize_transforms(cfg, output_dir=output_dir, n_examples=n_examples)
|
||||
output_dir = output_dir / repo_id.split("/")[-1]
|
||||
|
||||
# Check if the original frame image exists
|
||||
assert (output_dir / "original_frame.png").exists(), "Original frame image was not saved."
|
||||
assert (
|
||||
output_dir / "original_frame.png"
|
||||
).exists(), "Original frame image was not saved."
|
||||
|
||||
# Check if the transformed images exist for each transform type
|
||||
transforms = ["brightness", "contrast", "saturation", "hue", "sharpness"]
|
||||
for transform in transforms:
|
||||
transform_dir = output_dir / transform
|
||||
assert transform_dir.exists(), f"{transform} directory was not created."
|
||||
assert any(transform_dir.iterdir()), f"No transformed images found in {transform} directory."
|
||||
assert any(
|
||||
transform_dir.iterdir()
|
||||
), f"No transformed images found in {transform} directory."
|
||||
|
||||
# Check for specific files within each transform directory
|
||||
expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + ["min.png", "max.png", "mean.png"]
|
||||
expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + [
|
||||
"min.png",
|
||||
"max.png",
|
||||
"mean.png",
|
||||
]
|
||||
for file_name in expected_files:
|
||||
assert (
|
||||
transform_dir / file_name
|
||||
@@ -292,7 +310,9 @@ def test_visualize_image_transforms(repo_id, n_examples):
|
||||
|
||||
# Check if the combined transforms directory exists and contains the right files
|
||||
combined_transforms_dir = output_dir / "all"
|
||||
assert combined_transforms_dir.exists(), "Combined transforms directory was not created."
|
||||
assert (
|
||||
combined_transforms_dir.exists()
|
||||
), "Combined transforms directory was not created."
|
||||
assert any(
|
||||
combined_transforms_dir.iterdir()
|
||||
), "No transformed images found in combined transforms directory."
|
||||
|
||||
Reference in New Issue
Block a user