forked from tangger/lerobot

Compare commits: realman-si ... user/rcade
40 Commits

| Author | SHA1 | Date |
|---|---|---|
| | ca81b1d6f4 | |
| | e52942a200 | |
| | b60810a8b6 | |
| | faacb36271 | |
| | c45dd8f848 | |
| | a86f387554 | |
| | bdc0ebd36a | |
| | 19f4a6568d | |
| | 9552a4f010 | |
| | d657139828 | |
| | b1714803a3 | |
| | 5d55b19cbd | |
| | 641d349df4 | |
| | e444b0d529 | |
| | 8237ed9aa4 | |
| | 82e32f1fcd | |
| | 644e77e413 | |
| | 1b1bbb1632 | |
| | 0fb3dd745b | |
| | 4dbc1adb0d | |
| | ceb95592af | |
| | 6509c3f6d4 | |
| | 8b134725d5 | |
| | a544949ebe | |
| | fdf56e7a62 | |
| | 443b06b412 | |
| | 22bd1f0669 | |
| | 31e3c82386 | |
| | 5eea2542d9 | |
| | 42f9cc9c2a | |
| | 66629a956d | |
| | 7be2c35c0a | |
| | 14291171cc | |
| | cc4b3bd8e7 | |
| | 602ea9844b | |
| | 9f8415fa83 | |
| | 212a5ab29b | |
| | c4870e5892 | |
| | 20a3715469 | |
| | 65e46a49e1 | |
examples/6_add_image_transforms.py (Normal file, 51 lines)
@@ -0,0 +1,51 @@
"""
This script demonstrates how to apply torchvision image augmentations to a LeRobotDataset and how
to save a few transformed images for inspection.

The transforms are passed to the dataset as an argument at creation time and are applied to the
observation images before they are returned.
"""

from pathlib import Path

from torchvision.transforms import ToPILImage, v2

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

to_pil = ToPILImage()

# Create a directory to store output images
output_dir = Path("outputs/image_transforms")
output_dir.mkdir(parents=True, exist_ok=True)

repo_id = "lerobot/aloha_static_tape"

# Create a LeRobotDataset with no transformations
dataset = LeRobotDataset(repo_id, image_transforms=None)

# Get the index of the first observation in the first episode
first_idx = dataset.episode_data_index["from"][0].item()

# Get the frame corresponding to the first camera
frame = dataset[first_idx][dataset.camera_keys[0]]

# Save the original frame
to_pil(frame).save(output_dir / "original_frame.png", quality=100)
print(f"Original frame saved to {output_dir / 'original_frame.png'}.")

# Define the transformations
transforms = v2.Compose(
    [
        v2.ColorJitter(brightness=(0.5, 1.5)),
        v2.ColorJitter(contrast=(0.5, 1.5)),
        v2.RandomAdjustSharpness(sharpness_factor=2, p=1),
    ]
)

# Create another LeRobotDataset with the defined transformations
transformed_dataset = LeRobotDataset(repo_id, image_transforms=transforms)

# Get a frame from the transformed dataset
transformed_frame = transformed_dataset[first_idx][transformed_dataset.camera_keys[0]]

# Save the transformed frame
to_pil(transformed_frame).save(output_dir / "transformed_frame.png", quality=100)
print(f"Transformed frame saved to {output_dir / 'transformed_frame.png'}.")
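For reference, the same augmentation can also be expressed with the `get_image_transforms` helper introduced later in this changeset instead of a hand-built `v2.Compose`. This is a minimal sketch, not part of the commit, and it assumes the helper's behavior exactly as defined in the new `lerobot/common/datasets/transforms.py` shown below.

```python
# Sketch: equivalent augmentation built with the new get_image_transforms helper.
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.transforms import get_image_transforms

transforms = get_image_transforms(
    brightness_min_max=(0.5, 1.5),  # same range as v2.ColorJitter(brightness=(0.5, 1.5))
    contrast_min_max=(0.5, 1.5),    # same range as v2.ColorJitter(contrast=(0.5, 1.5))
    sharpness_min_max=(2.0, 2.0),   # fixed factor 2, like v2.RandomAdjustSharpness(2, p=1)
    max_num_transforms=3,           # apply all three transforms to every frame
    random_order=False,
)
transformed_dataset = LeRobotDataset("lerobot/aloha_static_tape", image_transforms=transforms)
```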
@@ -19,6 +19,7 @@ import torch
 from omegaconf import ListConfig, OmegaConf

 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
+from lerobot.common.datasets.transforms import get_image_transforms


 def resolve_delta_timestamps(cfg):
@@ -71,17 +72,36 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotDataset:

     resolve_delta_timestamps(cfg)

-    # TODO(rcadene): add data augmentations
+    image_transforms = None
+    if cfg.image_transforms.enable:
+        image_transforms = get_image_transforms(
+            brightness_weight=cfg.image_transforms.brightness.weight,
+            brightness_min_max=cfg.image_transforms.brightness.min_max,
+            contrast_weight=cfg.image_transforms.contrast.weight,
+            contrast_min_max=cfg.image_transforms.contrast.min_max,
+            saturation_weight=cfg.image_transforms.saturation.weight,
+            saturation_min_max=cfg.image_transforms.saturation.min_max,
+            hue_weight=cfg.image_transforms.hue.weight,
+            hue_min_max=cfg.image_transforms.hue.min_max,
+            sharpness_weight=cfg.image_transforms.sharpness.weight,
+            sharpness_min_max=cfg.image_transforms.sharpness.min_max,
+            max_num_transforms=cfg.image_transforms.max_num_transforms,
+            random_order=cfg.image_transforms.random_order,
+        )

     if isinstance(cfg.dataset_repo_id, str):
         dataset = LeRobotDataset(
             cfg.dataset_repo_id,
             split=split,
             delta_timestamps=cfg.training.get("delta_timestamps"),
+            image_transforms=image_transforms,
         )
     else:
         dataset = MultiLeRobotDataset(
-            cfg.dataset_repo_id, split=split, delta_timestamps=cfg.training.get("delta_timestamps")
+            cfg.dataset_repo_id,
+            split=split,
+            delta_timestamps=cfg.training.get("delta_timestamps"),
+            image_transforms=image_transforms,
         )

     if cfg.get("override_dataset_stats"):
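For context, this is a minimal sketch of a config fragment that would drive the new `image_transforms` branch of `make_dataset`; the keys mirror the `image_transforms` block added to the default yaml further down in this diff, and the values shown here are arbitrary.

```python
# Sketch: smallest OmegaConf fragment that exercises the image_transforms branch above.
from omegaconf import OmegaConf

cfg_fragment = OmegaConf.create(
    {
        "image_transforms": {
            "enable": True,
            "max_num_transforms": 3,
            "random_order": False,
            "brightness": {"weight": 1, "min_max": [0.8, 1.2]},
            "contrast": {"weight": 1, "min_max": [0.8, 1.2]},
            "saturation": {"weight": 1, "min_max": [0.5, 1.5]},
            "hue": {"weight": 1, "min_max": [-0.05, 0.05]},
            "sharpness": {"weight": 1, "min_max": [0.8, 1.2]},
        }
    }
)
```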
@@ -46,7 +46,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         version: str | None = CODEBASE_VERSION,
         root: Path | None = DATA_DIR,
         split: str = "train",
-        transform: Callable | None = None,
+        image_transforms: Callable | None = None,
         delta_timestamps: dict[list[float]] | None = None,
     ):
         super().__init__()
@@ -54,7 +54,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.version = version
         self.root = root
         self.split = split
-        self.transform = transform
+        self.image_transforms = image_transforms
         self.delta_timestamps = delta_timestamps
         # load data from hub or locally when root is provided
         # TODO(rcadene, aliberts): implement faster transfer
@@ -151,8 +151,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 self.tolerance_s,
             )

-        if self.transform is not None:
-            item = self.transform(item)
+        if self.image_transforms is not None:
+            for cam in self.camera_keys:
+                item[cam] = self.image_transforms(item[cam])

         return item

@@ -168,7 +169,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             f" Recorded Frames per Second: {self.fps},\n"
             f" Camera Keys: {self.camera_keys},\n"
             f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
-            f" Transformations: {self.transform},\n"
+            f" Transformations: {self.image_transforms},\n"
             f")"
         )

@@ -202,7 +203,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         obj.version = version
         obj.root = root
         obj.split = split
-        obj.transform = transform
+        obj.image_transforms = transform
         obj.delta_timestamps = delta_timestamps
         obj.hf_dataset = hf_dataset
         obj.episode_data_index = episode_data_index
@@ -225,7 +226,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
         version: str | None = CODEBASE_VERSION,
         root: Path | None = DATA_DIR,
         split: str = "train",
-        transform: Callable | None = None,
+        image_transforms: Callable | None = None,
         delta_timestamps: dict[list[float]] | None = None,
     ):
         super().__init__()
@@ -239,7 +240,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
                 root=root,
                 split=split,
                 delta_timestamps=delta_timestamps,
-                transform=transform,
+                image_transforms=image_transforms,
             )
             for repo_id in repo_ids
         ]
@@ -274,7 +275,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
         self.version = version
         self.root = root
         self.split = split
-        self.transform = transform
+        self.image_transforms = image_transforms
         self.delta_timestamps = delta_timestamps
         self.stats = aggregate_stats(self._datasets)

@@ -380,6 +381,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
         for data_key in self.disabled_data_keys:
             if data_key in item:
                 del item[data_key]
+
         return item

     def __repr__(self):
@@ -394,6 +396,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
             f" Recorded Frames per Second: {self.fps},\n"
             f" Camera Keys: {self.camera_keys},\n"
             f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
-            f" Transformations: {self.transform},\n"
+            f" Transformations: {self.image_transforms},\n"
             f")"
         )
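The `__getitem__` change above applies `image_transforms` once per camera key and leaves non-image keys untouched. Below is a minimal sketch of what a consumer sees; it reuses the dataset repo id from the example script earlier in this diff and is not part of the commit.

```python
# Sketch: image_transforms runs on every camera key of an item; other keys pass through.
from torchvision.transforms import v2

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset("lerobot/aloha_static_tape", image_transforms=v2.ColorJitter(hue=(-0.05, 0.05)))
item = dataset[0]
for cam in dataset.camera_keys:
    print(cam, item[cam].shape)  # each camera image was jittered independently
```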
lerobot/common/datasets/transforms.py (Normal file, 157 lines)
@@ -0,0 +1,157 @@
from typing import Any, Callable, Dict, Sequence

import torch
from torchvision.transforms import v2
from torchvision.transforms.v2 import Transform
from torchvision.transforms.v2 import functional as F  # noqa: N812


class RandomSubsetApply(Transform):
    """
    Apply a random subset of N transformations from a list of transformations.

    Args:
        transforms (sequence or torch.nn.Module): list of transformations
        p (list of floats or None, optional): probability of each transform being picked.
            If ``p`` doesn't sum to 1, it is automatically normalized. If ``None``
            (default), all transforms have the same probability.
        n_subset (int or None): number of transformations to apply. If ``None``,
            all transforms are applied.
        random_order (bool): apply transformations in a random order
    """

    def __init__(
        self,
        transforms: Sequence[Callable],
        p: list[float] | None = None,
        n_subset: int | None = None,
        random_order: bool = False,
    ) -> None:
        super().__init__()
        if not isinstance(transforms, Sequence):
            raise TypeError("Argument transforms should be a sequence of callables")
        if p is None:
            p = [1] * len(transforms)
        elif len(p) != len(transforms):
            raise ValueError(
                f"Length of p doesn't match the number of transforms: {len(p)} != {len(transforms)}"
            )

        if n_subset is None:
            n_subset = len(transforms)
        elif not isinstance(n_subset, int):
            raise TypeError("n_subset should be an int or None")
        elif not (0 <= n_subset <= len(transforms)):
            raise ValueError(f"n_subset should be in the interval [0, {len(transforms)}]")

        self.transforms = transforms
        total = sum(p)
        self.p = [prob / total for prob in p]
        self.n_subset = n_subset
        self.random_order = random_order

    def forward(self, *inputs: Any) -> Any:
        needs_unpacking = len(inputs) > 1

        selected_indices = torch.multinomial(torch.tensor(self.p), self.n_subset)
        if not self.random_order:
            selected_indices = selected_indices.sort().values

        selected_transforms = [self.transforms[i] for i in selected_indices]

        for transform in selected_transforms:
            outputs = transform(*inputs)
            inputs = outputs if needs_unpacking else (outputs,)

        return outputs

    def extra_repr(self) -> str:
        return (
            f"transforms={self.transforms}, "
            f"p={self.p}, "
            f"n_subset={self.n_subset}, "
            f"random_order={self.random_order}"
        )


class RangeRandomSharpness(Transform):
    """Similar to v2.RandomAdjustSharpness but with p=1 and a sharpness_factor sampled uniformly
    each time in [range_min, range_max].

    If the input is a :class:`torch.Tensor`, it is expected to have [..., 1 or 3, H, W] shape,
    where ... means an arbitrary number of leading dimensions.
    """

    def __init__(self, range_min: float, range_max: float) -> None:
        super().__init__()
        self.range_min, self.range_max = self._check_input(range_min, range_max)

    def _check_input(self, range_min, range_max):
        if range_min < 0:
            raise ValueError("range_min must be non negative.")
        if range_min > range_max:
            raise ValueError("range_max must be greater than or equal to range_min.")
        return range_min, range_max

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        sharpness_factor = self.range_min + (self.range_max - self.range_min) * torch.rand(1).item()
        return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)


def get_image_transforms(
    brightness_weight: float = 1.0,
    brightness_min_max: tuple[float, float] | None = None,
    contrast_weight: float = 1.0,
    contrast_min_max: tuple[float, float] | None = None,
    saturation_weight: float = 1.0,
    saturation_min_max: tuple[float, float] | None = None,
    hue_weight: float = 1.0,
    hue_min_max: tuple[float, float] | None = None,
    sharpness_weight: float = 1.0,
    sharpness_min_max: tuple[float, float] | None = None,
    max_num_transforms: int | None = None,
    random_order: bool = False,
):
    def check_value_error(name, weight, min_max):
        if min_max is not None:
            if len(min_max) != 2:
                raise ValueError(
                    f"`{name}_min_max` is expected to be a tuple of 2 values, but {min_max} was provided."
                )
        if weight < 0.0:
            raise ValueError(f"`{name}_weight` is expected to be 0 or positive, but is negative ({weight}).")

    check_value_error("brightness", brightness_weight, brightness_min_max)
    check_value_error("contrast", contrast_weight, contrast_min_max)
    check_value_error("saturation", saturation_weight, saturation_min_max)
    check_value_error("hue", hue_weight, hue_min_max)
    check_value_error("sharpness", sharpness_weight, sharpness_min_max)

    weights = []
    transforms = []
    if brightness_min_max is not None:
        weights.append(brightness_weight)
        transforms.append(v2.ColorJitter(brightness=brightness_min_max))
    if contrast_min_max is not None:
        weights.append(contrast_weight)
        transforms.append(v2.ColorJitter(contrast=contrast_min_max))
    if saturation_min_max is not None:
        weights.append(saturation_weight)
        transforms.append(v2.ColorJitter(saturation=saturation_min_max))
    if hue_min_max is not None:
        weights.append(hue_weight)
        transforms.append(v2.ColorJitter(hue=hue_min_max))
    if sharpness_min_max is not None:
        weights.append(sharpness_weight)
        transforms.append(RangeRandomSharpness(*sharpness_min_max))

    if max_num_transforms is None:
        n_subset = len(transforms)
    else:
        n_subset = min(len(transforms), max_num_transforms)

    final_transforms = RandomSubsetApply(
        transforms, p=weights, n_subset=n_subset, random_order=random_order
    )

    # TODO(rcadene, aliberts): add v2.ToDtype float16?
    return final_transforms
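A minimal usage sketch of the two new building blocks follows; the image is random data just to show the call pattern, and nothing in this snippet is part of the commit.

```python
import torch
from torchvision.transforms import v2

from lerobot.common.datasets.transforms import RandomSubsetApply, RangeRandomSharpness, get_image_transforms

img = torch.rand(3, 96, 96)  # dummy CHW float image in [0, 1]

# Pick 2 of 3 transforms per call; contrast is twice as likely to be selected as the others.
subset = RandomSubsetApply(
    [v2.ColorJitter(brightness=0.5), v2.ColorJitter(contrast=0.5), RangeRandomSharpness(0.5, 1.5)],
    p=[1, 2, 1],
    n_subset=2,
    random_order=True,
)
out = subset(img)

# Same idea through the factory used by make_dataset.
tf = get_image_transforms(
    brightness_min_max=(0.8, 1.2),
    contrast_min_max=(0.8, 1.2),
    sharpness_min_max=(0.8, 1.2),
    max_num_transforms=2,
    random_order=True,
)
out = tf(img)
```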
@@ -57,3 +57,37 @@ wandb:
   disable_artifact: false
   project: lerobot
   notes: ""
+
+image_transforms:
+  # These transforms are all using standard torchvision.transforms.v2
+  # You can find out how these transformations affect images here:
+  # https://pytorch.org/vision/0.18/auto_examples/transforms/plot_transforms_illustrations.html
+  # We use a custom RandomSubsetApply container to sample them.
+  # For each transform, the following parameters are available:
+  #   weight: This represents the multinomial probability (with no replacement)
+  #           used for sampling the transform. If the sum of the weights is not 1,
+  #           they will be normalized.
+  #   min_max: Lower & upper bound respectively used for sampling the transform's parameter
+  #            (following a uniform distribution) when it's applied.
+  enable: false
+  # This is the number of transforms (sampled from those below) that will be applied to each frame.
+  # It's an integer in the interval [0, number of available transforms].
+  max_num_transforms: 3
+  # By default, transforms are applied in Torchvision's suggested order (shown below).
+  # Set this to True to apply them in a random order.
+  random_order: false
+  brightness:
+    weight: 1
+    min_max: [0.8, 1.2]
+  contrast:
+    weight: 1
+    min_max: [0.8, 1.2]
+  saturation:
+    weight: 1
+    min_max: [0.5, 1.5]
+  hue:
+    weight: 1
+    min_max: [-0.05, 0.05]
+  sharpness:
+    weight: 1
+    min_max: [0.8, 1.2]
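As a concrete reading of the `weight` semantics above, here is a short sketch (not part of the commit) of how the default weights turn into sampling probabilities inside `RandomSubsetApply`:

```python
# Sketch: with equal weights and max_num_transforms: 3, each frame gets 3 distinct transforms,
# each drawn (without replacement) from a normalized probability of 0.2.
import torch

weights = [1, 1, 1, 1, 1]  # brightness, contrast, saturation, hue, sharpness
p = [w / sum(weights) for w in weights]
selected = torch.multinomial(torch.tensor(p), 3)  # indices of the transforms applied to this frame
print(p, selected)
```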
lerobot/scripts/visualize_image_transforms.py (Normal file, 46 lines)
@@ -0,0 +1,46 @@
from pathlib import Path

import hydra
from torchvision.transforms import ToPILImage

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.transforms import make_image_transforms

to_pil = ToPILImage()


def main(cfg, output_dir=Path("outputs/image_transforms")):
    dataset = LeRobotDataset(cfg.dataset_repo_id, image_transforms=None)

    output_dir = Path(output_dir) / Path(cfg.dataset_repo_id.split("/")[-1])
    output_dir.mkdir(parents=True, exist_ok=True)

    # Get first frame of 1st episode
    first_idx = dataset.episode_data_index["from"][0].item()
    frame = dataset[first_idx][dataset.camera_keys[0]]
    to_pil(frame).save(output_dir / "original_frame.png", quality=100)

    transforms = ["brightness", "contrast", "saturation", "hue", "sharpness"]

    # Apply each single transformation
    for transform_name in transforms:
        for t in transforms:
            if t == transform_name:
                cfg.image_transforms[t].weight = 1
            else:
                cfg.image_transforms[t].weight = 0

        transform = make_image_transforms(cfg.image_transforms)
        img = transform(frame)
        to_pil(img).save(output_dir / f"{transform_name}.png", quality=100)


@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
def visualize_transforms_cli(cfg: dict):
    main(
        cfg,
    )


if __name__ == "__main__":
    visualize_transforms_cli()
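Note that the script imports `make_image_transforms`, which the new `transforms.py` in this diff does not define (it exposes `get_image_transforms`). As a hedged sketch only, the per-transform previews could be built directly on `get_image_transforms`; `single_transform` below is a hypothetical helper, not part of the commit.

```python
# Hypothetical helper (not in the commit): build a transform that applies exactly one of the
# configured augmentations, using get_image_transforms instead of make_image_transforms.
from lerobot.common.datasets.transforms import get_image_transforms


def single_transform(cfg_image_transforms, name):
    kwargs = {}
    for t in ["brightness", "contrast", "saturation", "hue", "sharpness"]:
        # Zero every weight except `name`, keep the configured min_max ranges.
        kwargs[f"{t}_weight"] = 1.0 if t == name else 0.0
        kwargs[f"{t}_min_max"] = tuple(cfg_image_transforms[t].min_max)
    # max_num_transforms=1 makes RandomSubsetApply sample only the transform whose weight is non-zero.
    return get_image_transforms(max_num_transforms=1, random_order=False, **kwargs)
```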
tests/data/save_image_transforms/original_frame.png (Normal file, BIN, 185 KiB; binary file not shown)
tests/scripts/save_image_transforms.py (Normal file, 47 lines)
@@ -0,0 +1,47 @@
from pathlib import Path

import torch
from torchvision.transforms import v2
from safetensors.torch import save_file

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.transforms import RangeRandomSharpness
from lerobot.common.utils.utils import seeded_context

DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml"
ARTIFACT_DIR = "tests/data/save_image_transforms"
SEED = 1336
to_pil = v2.ToPILImage()


def main(repo_id):
    dataset = LeRobotDataset(repo_id, image_transforms=None)
    output_dir = Path(ARTIFACT_DIR)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Get first frame of given episode
    from_idx = dataset.episode_data_index["from"][0].item()
    original_frame = dataset[from_idx][dataset.camera_keys[0]]
    to_pil(original_frame).save(output_dir / "original_frame.png", quality=100)

    transforms = {
        "brightness": v2.ColorJitter(brightness=(0.0, 2.0)),
        "contrast": v2.ColorJitter(contrast=(0.0, 2.0)),
        "saturation": v2.ColorJitter(saturation=(0.0, 2.0)),
        "hue": v2.ColorJitter(hue=(-0.5, 0.5)),
        "sharpness": RangeRandomSharpness(0.0, 2.0),
    }

    # frames = {"original_frame": original_frame}
    for name, transform in transforms.items():
        with seeded_context(SEED):
            # transform = v2.Compose([transform, v2.ToDtype(torch.float32, scale=True)])
            transformed_frame = transform(original_frame)
            # frames[name] = transform(original_frame)
        to_pil(transformed_frame).save(output_dir / f"{SEED}_{name}.png", quality=100)

    # save_file(frames, output_dir / f"transformed_frames_{SEED}.safetensors")


if __name__ == "__main__":
    repo_id = "lerobot/aloha_mobile_shrimp"
    main(repo_id)
tests/test_transforms.py (Normal file, 253 lines)
@@ -0,0 +1,253 @@
from pathlib import Path

import numpy as np
import pytest
import torch
from omegaconf import OmegaConf
from PIL import Image
from torchvision.transforms import v2
from torchvision.transforms.v2 import functional as F  # noqa: N812

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.transforms import RandomSubsetApply, RangeRandomSharpness, get_image_transforms
from lerobot.common.utils.utils import seeded_context

# test_make_image_transforms
# -

# test backward compatibility torchvision
# - save artifacts

# test backward compatibility default yaml (enable false, enable true)
# - save artifacts


def test_get_image_transforms_no_transform():
    get_image_transforms()
    get_image_transforms(sharpness_weight=0.0)
    get_image_transforms(max_num_transforms=0)


@pytest.fixture
def img():
    # dataset = LeRobotDataset("lerobot/pusht")
    # item = dataset[0]
    # return item["observation.image"]
    path = "tests/data/save_image_transforms/original_frame.png"
    img_chw = torch.from_numpy(np.array(Image.open(path).convert("RGB"))).permute(2, 0, 1)
    return img_chw


def test_get_image_transforms_brightness(img):
    brightness_min_max = (0.5, 0.5)
    tf_actual = get_image_transforms(brightness_weight=1.0, brightness_min_max=brightness_min_max)
    tf_expected = v2.ColorJitter(brightness=brightness_min_max)
    torch.testing.assert_close(tf_actual(img), tf_expected(img))


def test_get_image_transforms_contrast(img):
    contrast_min_max = (0.5, 0.5)
    tf_actual = get_image_transforms(contrast_weight=1.0, contrast_min_max=contrast_min_max)
    tf_expected = v2.ColorJitter(contrast=contrast_min_max)
    torch.testing.assert_close(tf_actual(img), tf_expected(img))


def test_get_image_transforms_saturation(img):
    saturation_min_max = (0.5, 0.5)
    tf_actual = get_image_transforms(saturation_weight=1.0, saturation_min_max=saturation_min_max)
    tf_expected = v2.ColorJitter(saturation=saturation_min_max)
    torch.testing.assert_close(tf_actual(img), tf_expected(img))


def test_get_image_transforms_hue(img):
    hue_min_max = (0.5, 0.5)
    tf_actual = get_image_transforms(hue_weight=1.0, hue_min_max=hue_min_max)
    tf_expected = v2.ColorJitter(hue=hue_min_max)
    torch.testing.assert_close(tf_actual(img), tf_expected(img))


def test_get_image_transforms_sharpness(img):
    sharpness_min_max = (0.5, 0.5)
    tf_actual = get_image_transforms(sharpness_weight=1.0, sharpness_min_max=sharpness_min_max)
    tf_expected = RangeRandomSharpness(*sharpness_min_max)
    torch.testing.assert_close(tf_actual(img), tf_expected(img))


def test_get_image_transforms_max_num_transforms(img):
    tf_actual = get_image_transforms(
        brightness_min_max=(0.5, 0.5),
        contrast_min_max=(0.5, 0.5),
        saturation_min_max=(0.5, 0.5),
        hue_min_max=(0.5, 0.5),
        sharpness_min_max=(0.5, 0.5),
        max_num_transforms=5,
        random_order=False,
    )
    tf_expected = v2.Compose([
        v2.ColorJitter(brightness=(0.5, 0.5)),
        v2.ColorJitter(contrast=(0.5, 0.5)),
        v2.ColorJitter(saturation=(0.5, 0.5)),
        v2.ColorJitter(hue=(0.5, 0.5)),
        RangeRandomSharpness(0.5, 0.5),
    ])
    torch.testing.assert_close(tf_actual(img), tf_expected(img))


def test_get_image_transforms_random_order(img):
    out_imgs = []
    with seeded_context(1337):
        for _ in range(20):
            tf = get_image_transforms(
                brightness_min_max=(0.5, 0.5),
                contrast_min_max=(0.5, 0.5),
                saturation_min_max=(0.5, 0.5),
                hue_min_max=(0.5, 0.5),
                sharpness_min_max=(0.5, 0.5),
                random_order=True,
            )
            out_imgs.append(tf(img))

    for i in range(1, 10):
        with pytest.raises(AssertionError):
            torch.testing.assert_close(out_imgs[0], out_imgs[i])


def test_backward_compatibility_torchvision():
    pass


def test_backward_compatibility_default_yaml():
    pass

# class TestRandomSubsetApply:
#     @pytest.fixture(autouse=True)
#     def setup(self):
#         self.jitters = [
#             v2.ColorJitter(brightness=0.5),
#             v2.ColorJitter(contrast=0.5),
#             v2.ColorJitter(saturation=0.5),
#         ]
#         self.flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
#         self.img = torch.rand(3, 224, 224)

#     @pytest.mark.parametrize("p", [[0, 1], [1, 0]])
#     def test_random_choice(self, p):
#         random_choice = RandomSubsetApply(self.flips, p=p, n_subset=1, random_order=False)
#         output = random_choice(self.img)

#         p_horz, _ = p
#         if p_horz:
#             torch.testing.assert_close(output, F.horizontal_flip(self.img))
#         else:
#             torch.testing.assert_close(output, F.vertical_flip(self.img))

#     def test_transform_all(self):
#         transform = RandomSubsetApply(self.jitters)
#         output = transform(self.img)
#         assert output.shape == self.img.shape

#     def test_transform_subset(self):
#         transform = RandomSubsetApply(self.jitters, n_subset=2)
#         output = transform(self.img)
#         assert output.shape == self.img.shape

#     def test_random_order(self):
#         random_order = RandomSubsetApply(self.flips, p=[0.5, 0.5], n_subset=2, random_order=True)
#         # We can't really check whether the transforms are actually applied in random order. However,
#         # horizontal and vertical flip are commutative. Meaning, even under the assumption that the transform
#         # applies them in random order, we can use a fixed order to compute the expected value.
#         actual = random_order(self.img)
#         expected = v2.Compose(self.flips)(self.img)
#         torch.testing.assert_close(actual, expected)

#     def test_probability_length_mismatch(self):
#         with pytest.raises(ValueError):
#             RandomSubsetApply(self.jitters, p=[0.5, 0.5])

#     def test_invalid_n_subset(self):
#         with pytest.raises(ValueError):
#             RandomSubsetApply(self.jitters, n_subset=5)


# class TestRangeRandomSharpness:
#     @pytest.fixture(autouse=True)
#     def setup(self):
#         self.img = torch.rand(3, 224, 224)

#     def test_valid_range(self):
#         transform = RangeRandomSharpness(0.1, 2.0)
#         output = transform(self.img)
#         assert output.shape == self.img.shape

#     def test_invalid_range_min_negative(self):
#         with pytest.raises(ValueError):
#             RangeRandomSharpness(-0.1, 2.0)

#     def test_invalid_range_max_smaller(self):
#         with pytest.raises(ValueError):
#             RangeRandomSharpness(2.0, 0.1)


# class TestMakeImageTransforms:
#     @pytest.fixture(autouse=True)
#     def setup(self):
#         """Seed should be the same as the one that was used to generate artifacts"""
#         self.config = {
#             "enable": True,
#             "max_num_transforms": 1,
#             "random_order": False,
#             "brightness": {"weight": 0, "min": 2.0, "max": 2.0},
#             "contrast": {
#                 "weight": 0,
#                 "min": 2.0,
#                 "max": 2.0,
#             },
#             "saturation": {
#                 "weight": 0,
#                 "min": 2.0,
#                 "max": 2.0,
#             },
#             "hue": {
#                 "weight": 0,
#                 "min": 0.5,
#                 "max": 0.5,
#             },
#             "sharpness": {
#                 "weight": 0,
#                 "min": 2.0,
#                 "max": 2.0,
#             },
#         }
#         self.path = Path("tests/data/save_image_transforms")
#         self.original_frame = self.load_png_to_tensor(self.path / "original_frame.png")
#         self.transforms = {
#             "brightness": v2.ColorJitter(brightness=(2.0, 2.0)),
#             "contrast": v2.ColorJitter(contrast=(2.0, 2.0)),
#             "saturation": v2.ColorJitter(saturation=(2.0, 2.0)),
#             "hue": v2.ColorJitter(hue=(0.5, 0.5)),
#             "sharpness": RangeRandomSharpness(2.0, 2.0),
#         }

#     @staticmethod
#     def load_png_to_tensor(path: Path):
#         return torch.from_numpy(np.array(Image.open(path).convert("RGB"))).permute(2, 0, 1)

#     @pytest.mark.parametrize(
#         "transform_key, seed",
#         [
#             ("brightness", 1336),
#             ("contrast", 1336),
#             ("saturation", 1336),
#             ("hue", 1336),
#             ("sharpness", 1336),
#         ],
#     )
#     def test_single_transform(self, transform_key, seed):
#         config = self.config
#         config[transform_key]["weight"] = 1
#         cfg = OmegaConf.create(config)

#         actual_t = make_image_transforms(cfg, to_dtype=torch.uint8)
#         with seeded_context(1336):
#             actual = actual_t(self.original_frame)

#         expected_t = self.transforms[transform_key]
#         with seeded_context(1336):
#             expected = expected_t(self.original_frame)

#         torch.testing.assert_close(actual, expected)