[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-03-04 13:38:47 +00:00
parent d8a1758122
commit 584cad808e
108 changed files with 3894 additions and 1189 deletions

View File

@@ -21,7 +21,11 @@ from safetensors.torch import load_file
from torchvision.transforms import v2
from torchvision.transforms.v2 import functional as F # noqa: N812
from lerobot.common.datasets.transforms import RandomSubsetApply, SharpnessJitter, get_image_transforms
from lerobot.common.datasets.transforms import (
RandomSubsetApply,
SharpnessJitter,
get_image_transforms,
)
from lerobot.common.utils.utils import init_hydra_config, seeded_context
from lerobot.scripts.visualize_image_transforms import visualize_transforms
from tests.utils import DEFAULT_CONFIG_PATH, require_x86_64_kernel
@@ -51,7 +55,9 @@ def default_transforms():
def test_get_image_transforms_no_transform(img_tensor_factory):
img_tensor = img_tensor_factory()
tf_actual = get_image_transforms(brightness_min_max=(0.5, 0.5), max_num_transforms=0)
tf_actual = get_image_transforms(
brightness_min_max=(0.5, 0.5), max_num_transforms=0
)
torch.testing.assert_close(tf_actual(img_tensor), img_tensor)
@@ -149,7 +155,9 @@ def test_get_image_transforms_random_order(img_tensor_factory):
("sharpness", [(0.5, 0.5), (2.0, 2.0)]),
],
)
def test_backward_compatibility_torchvision(img_tensor_factory, transform, min_max_values, single_transforms):
def test_backward_compatibility_torchvision(
img_tensor_factory, transform, min_max_values, single_transforms
):
img_tensor = img_tensor_factory()
for min_max in min_max_values:
kwargs = {
@@ -268,23 +276,33 @@ def test_sharpness_jitter_invalid_range_max_smaller():
],
)
def test_visualize_image_transforms(repo_id, n_examples):
cfg = init_hydra_config(DEFAULT_CONFIG_PATH, overrides=[f"dataset_repo_id={repo_id}"])
cfg = init_hydra_config(
DEFAULT_CONFIG_PATH, overrides=[f"dataset_repo_id={repo_id}"]
)
output_dir = Path(__file__).parent / "outputs" / "image_transforms"
visualize_transforms(cfg, output_dir=output_dir, n_examples=n_examples)
output_dir = output_dir / repo_id.split("/")[-1]
# Check if the original frame image exists
assert (output_dir / "original_frame.png").exists(), "Original frame image was not saved."
assert (
output_dir / "original_frame.png"
).exists(), "Original frame image was not saved."
# Check if the transformed images exist for each transform type
transforms = ["brightness", "contrast", "saturation", "hue", "sharpness"]
for transform in transforms:
transform_dir = output_dir / transform
assert transform_dir.exists(), f"{transform} directory was not created."
assert any(transform_dir.iterdir()), f"No transformed images found in {transform} directory."
assert any(
transform_dir.iterdir()
), f"No transformed images found in {transform} directory."
# Check for specific files within each transform directory
expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + ["min.png", "max.png", "mean.png"]
expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + [
"min.png",
"max.png",
"mean.png",
]
for file_name in expected_files:
assert (
transform_dir / file_name
@@ -292,7 +310,9 @@ def test_visualize_image_transforms(repo_id, n_examples):
# Check if the combined transforms directory exists and contains the right files
combined_transforms_dir = output_dir / "all"
assert combined_transforms_dir.exists(), "Combined transforms directory was not created."
assert (
combined_transforms_dir.exists()
), "Combined transforms directory was not created."
assert any(
combined_transforms_dir.iterdir()
), "No transformed images found in combined transforms directory."