Dataset v2.0 (#461)

Co-authored-by: Remi <remi.cadene@huggingface.co>
This commit is contained in:
Simon Alibert
2024-11-29 19:04:00 +01:00
committed by GitHub
parent 96c7052777
commit 32eb0cec8f
71 changed files with 6115 additions and 2235 deletions

View File

@@ -15,15 +15,12 @@
# limitations under the License.
from pathlib import Path
import numpy as np
import pytest
import torch
from PIL import Image
from safetensors.torch import load_file
from torchvision.transforms import v2
from torchvision.transforms.v2 import functional as F # noqa: N812
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.transforms import RandomSubsetApply, SharpnessJitter, get_image_transforms
from lerobot.common.utils.utils import init_hydra_config, seeded_context
from lerobot.scripts.visualize_image_transforms import visualize_transforms
@@ -33,21 +30,6 @@ ARTIFACT_DIR = Path("tests/data/save_image_transforms_to_safetensors")
DATASET_REPO_ID = "lerobot/aloha_mobile_shrimp"
def load_png_to_tensor(path: Path):
return torch.from_numpy(np.array(Image.open(path).convert("RGB"))).permute(2, 0, 1)
@pytest.fixture
def img():
dataset = LeRobotDataset(DATASET_REPO_ID)
return dataset[0][dataset.camera_keys[0]]
@pytest.fixture
def img_random():
return torch.rand(3, 480, 640)
@pytest.fixture
def color_jitters():
return [
@@ -67,47 +49,54 @@ def default_transforms():
return load_file(ARTIFACT_DIR / "default_transforms.safetensors")
def test_get_image_transforms_no_transform(img):
def test_get_image_transforms_no_transform(img_tensor_factory):
img_tensor = img_tensor_factory()
tf_actual = get_image_transforms(brightness_min_max=(0.5, 0.5), max_num_transforms=0)
torch.testing.assert_close(tf_actual(img), img)
torch.testing.assert_close(tf_actual(img_tensor), img_tensor)
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_brightness(img, min_max):
def test_get_image_transforms_brightness(img_tensor_factory, min_max):
img_tensor = img_tensor_factory()
tf_actual = get_image_transforms(brightness_weight=1.0, brightness_min_max=min_max)
tf_expected = v2.ColorJitter(brightness=min_max)
torch.testing.assert_close(tf_actual(img), tf_expected(img))
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_contrast(img, min_max):
def test_get_image_transforms_contrast(img_tensor_factory, min_max):
img_tensor = img_tensor_factory()
tf_actual = get_image_transforms(contrast_weight=1.0, contrast_min_max=min_max)
tf_expected = v2.ColorJitter(contrast=min_max)
torch.testing.assert_close(tf_actual(img), tf_expected(img))
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_saturation(img, min_max):
def test_get_image_transforms_saturation(img_tensor_factory, min_max):
img_tensor = img_tensor_factory()
tf_actual = get_image_transforms(saturation_weight=1.0, saturation_min_max=min_max)
tf_expected = v2.ColorJitter(saturation=min_max)
torch.testing.assert_close(tf_actual(img), tf_expected(img))
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
@pytest.mark.parametrize("min_max", [(-0.25, -0.25), (0.25, 0.25)])
def test_get_image_transforms_hue(img, min_max):
def test_get_image_transforms_hue(img_tensor_factory, min_max):
img_tensor = img_tensor_factory()
tf_actual = get_image_transforms(hue_weight=1.0, hue_min_max=min_max)
tf_expected = v2.ColorJitter(hue=min_max)
torch.testing.assert_close(tf_actual(img), tf_expected(img))
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_sharpness(img, min_max):
def test_get_image_transforms_sharpness(img_tensor_factory, min_max):
img_tensor = img_tensor_factory()
tf_actual = get_image_transforms(sharpness_weight=1.0, sharpness_min_max=min_max)
tf_expected = SharpnessJitter(sharpness=min_max)
torch.testing.assert_close(tf_actual(img), tf_expected(img))
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
def test_get_image_transforms_max_num_transforms(img):
def test_get_image_transforms_max_num_transforms(img_tensor_factory):
img_tensor = img_tensor_factory()
tf_actual = get_image_transforms(
brightness_min_max=(0.5, 0.5),
contrast_min_max=(0.5, 0.5),
@@ -125,12 +114,13 @@ def test_get_image_transforms_max_num_transforms(img):
SharpnessJitter(sharpness=(0.5, 0.5)),
]
)
torch.testing.assert_close(tf_actual(img), tf_expected(img))
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
@require_x86_64_kernel
def test_get_image_transforms_random_order(img):
def test_get_image_transforms_random_order(img_tensor_factory):
out_imgs = []
img_tensor = img_tensor_factory()
tf = get_image_transforms(
brightness_min_max=(0.5, 0.5),
contrast_min_max=(0.5, 0.5),
@@ -141,13 +131,14 @@ def test_get_image_transforms_random_order(img):
)
with seeded_context(1337):
for _ in range(10):
out_imgs.append(tf(img))
out_imgs.append(tf(img_tensor))
for i in range(1, len(out_imgs)):
with pytest.raises(AssertionError):
torch.testing.assert_close(out_imgs[0], out_imgs[i])
@pytest.mark.skip("TODO after v2 migration / removing hydra")
@pytest.mark.parametrize(
"transform, min_max_values",
[
@@ -158,21 +149,24 @@ def test_get_image_transforms_random_order(img):
("sharpness", [(0.5, 0.5), (2.0, 2.0)]),
],
)
def test_backward_compatibility_torchvision(transform, min_max_values, img, single_transforms):
def test_backward_compatibility_torchvision(img_tensor_factory, transform, min_max_values, single_transforms):
img_tensor = img_tensor_factory()
for min_max in min_max_values:
kwargs = {
f"{transform}_weight": 1.0,
f"{transform}_min_max": min_max,
}
tf = get_image_transforms(**kwargs)
actual = tf(img)
actual = tf(img_tensor)
key = f"{transform}_{min_max[0]}_{min_max[1]}"
expected = single_transforms[key]
torch.testing.assert_close(actual, expected)
@pytest.mark.skip("TODO after v2 migration / removing hydra")
@require_x86_64_kernel
def test_backward_compatibility_default_config(img, default_transforms):
def test_backward_compatibility_default_config(img_tensor_factory, default_transforms):
img_tensor = img_tensor_factory()
cfg = init_hydra_config(DEFAULT_CONFIG_PATH)
cfg_tf = cfg.training.image_transforms
default_tf = get_image_transforms(
@@ -191,7 +185,7 @@ def test_backward_compatibility_default_config(img, default_transforms):
)
with seeded_context(1337):
actual = default_tf(img)
actual = default_tf(img_tensor)
expected = default_transforms["default"]
@@ -199,33 +193,36 @@ def test_backward_compatibility_default_config(img, default_transforms):
@pytest.mark.parametrize("p", [[0, 1], [1, 0]])
def test_random_subset_apply_single_choice(p, img):
def test_random_subset_apply_single_choice(img_tensor_factory, p):
img_tensor = img_tensor_factory()
flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
random_choice = RandomSubsetApply(flips, p=p, n_subset=1, random_order=False)
actual = random_choice(img)
actual = random_choice(img_tensor)
p_horz, _ = p
if p_horz:
torch.testing.assert_close(actual, F.horizontal_flip(img))
torch.testing.assert_close(actual, F.horizontal_flip(img_tensor))
else:
torch.testing.assert_close(actual, F.vertical_flip(img))
torch.testing.assert_close(actual, F.vertical_flip(img_tensor))
def test_random_subset_apply_random_order(img):
def test_random_subset_apply_random_order(img_tensor_factory):
img_tensor = img_tensor_factory()
flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
random_order = RandomSubsetApply(flips, p=[0.5, 0.5], n_subset=2, random_order=True)
# We can't really check whether the transforms are actually applied in random order. However,
# horizontal and vertical flip are commutative. Meaning, even under the assumption that the transform
# applies them in random order, we can use a fixed order to compute the expected value.
actual = random_order(img)
expected = v2.Compose(flips)(img)
actual = random_order(img_tensor)
expected = v2.Compose(flips)(img_tensor)
torch.testing.assert_close(actual, expected)
def test_random_subset_apply_valid_transforms(color_jitters, img):
def test_random_subset_apply_valid_transforms(img_tensor_factory, color_jitters):
img_tensor = img_tensor_factory()
transform = RandomSubsetApply(color_jitters)
output = transform(img)
assert output.shape == img.shape
output = transform(img_tensor)
assert output.shape == img_tensor.shape
def test_random_subset_apply_probability_length_mismatch(color_jitters):
@@ -239,16 +236,18 @@ def test_random_subset_apply_invalid_n_subset(color_jitters, n_subset):
RandomSubsetApply(color_jitters, n_subset=n_subset)
def test_sharpness_jitter_valid_range_tuple(img):
def test_sharpness_jitter_valid_range_tuple(img_tensor_factory):
img_tensor = img_tensor_factory()
tf = SharpnessJitter((0.1, 2.0))
output = tf(img)
assert output.shape == img.shape
output = tf(img_tensor)
assert output.shape == img_tensor.shape
def test_sharpness_jitter_valid_range_float(img):
def test_sharpness_jitter_valid_range_float(img_tensor_factory):
img_tensor = img_tensor_factory()
tf = SharpnessJitter(0.5)
output = tf(img)
assert output.shape == img.shape
output = tf(img_tensor)
assert output.shape == img_tensor.shape
def test_sharpness_jitter_invalid_range_min_negative():
@@ -261,6 +260,7 @@ def test_sharpness_jitter_invalid_range_max_smaller():
SharpnessJitter((2.0, 0.1))
@pytest.mark.skip("TODO after v2 migration / removing hydra")
@pytest.mark.parametrize(
"repo_id, n_examples",
[