rename to image_transforms

This commit is contained in:
Simon Alibert
2024-06-06 16:50:22 +00:00
parent a86f387554
commit c45dd8f848
8 changed files with 23 additions and 24 deletions

View File

@@ -1,5 +1,5 @@
"""
This script demonstrates how to implement torchvision image augmentation on an instance of a LeRobotDataset and how to show some transformed images.
This script demonstrates how to implement torchvision image augmentation on an instance of a LeRobotDataset and how to show some transformed images.
The transformations are passed to the dataset as an argument upon creation, and transforms are applied to the observation images before they are returned.
"""
@@ -18,7 +18,7 @@ output_dir.mkdir(parents=True, exist_ok=True)
repo_id = "lerobot/aloha_static_tape"
# Create a LeRobotDataset with no transformations
dataset = LeRobotDataset(repo_id, transform=None)
dataset = LeRobotDataset(repo_id, image_transforms=None)
# Get the index of the first observation in the first episode
first_idx = dataset.episode_data_index["from"][0].item()
@@ -41,7 +41,7 @@ transforms = v2.Compose(
)
# Create another LeRobotDataset with the defined transformations
transformed_dataset = LeRobotDataset(repo_id, transform=transforms)
transformed_dataset = LeRobotDataset(repo_id, image_transforms=transforms)
# Get a frame from the transformed dataset
transformed_frame = transformed_dataset[first_idx][transformed_dataset.camera_keys[0]]

View File

@@ -19,7 +19,7 @@ import torch
from omegaconf import ListConfig, OmegaConf
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
from lerobot.common.datasets.transforms import make_transforms
from lerobot.common.datasets.transforms import make_image_transforms
def resolve_delta_timestamps(cfg):
@@ -72,21 +72,21 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
resolve_delta_timestamps(cfg)
transform = make_transforms(cfg.image_transform) if cfg.image_transform.enable else None
image_transforms = make_image_transforms(cfg.image_transforms) if cfg.image_transforms.enable else None
if isinstance(cfg.dataset_repo_id, str):
dataset = LeRobotDataset(
cfg.dataset_repo_id,
split=split,
delta_timestamps=cfg.training.get("delta_timestamps"),
transform=transform,
image_transforms=image_transforms,
)
else:
dataset = MultiLeRobotDataset(
cfg.dataset_repo_id,
split=split,
delta_timestamps=cfg.training.get("delta_timestamps"),
transform=transform,
image_transforms=image_transforms,
)
if cfg.get("override_dataset_stats"):

View File

@@ -46,7 +46,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
version: str | None = CODEBASE_VERSION,
root: Path | None = DATA_DIR,
split: str = "train",
transform: Callable | None = None,
image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None,
):
super().__init__()
@@ -54,7 +54,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.version = version
self.root = root
self.split = split
self.transform = transform
self.transform = image_transforms
self.delta_timestamps = delta_timestamps
# load data from hub or locally when root is provided
# TODO(rcadene, aliberts): implement faster transfer
@@ -226,7 +226,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
version: str | None = CODEBASE_VERSION,
root: Path | None = DATA_DIR,
split: str = "train",
transform: Callable | None = None,
image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None,
):
super().__init__()
@@ -240,7 +240,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
root=root,
split=split,
delta_timestamps=delta_timestamps,
transform=transform,
image_transforms=image_transforms,
)
for repo_id in repo_ids
]
@@ -275,7 +275,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
self.version = version
self.root = root
self.split = split
self.transform = transform
self.transform = image_transforms
self.delta_timestamps = delta_timestamps
self.stats = aggregate_stats(self._datasets)

View File

@@ -98,7 +98,7 @@ class RangeRandomSharpness(Transform):
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
def make_transforms(cfg, to_dtype: torch.dtype = torch.float32):
def make_image_transforms(cfg, to_dtype: torch.dtype = torch.float32):
transforms_list = [
v2.ColorJitter(brightness=(cfg.brightness.min, cfg.brightness.max)),
v2.ColorJitter(contrast=(cfg.contrast.min, cfg.contrast.max)),

View File

@@ -58,10 +58,10 @@ wandb:
project: lerobot
notes: ""
image_transform:
image_transforms:
# brightness, contrast, saturation and hue are instances of torchvision ColorJitter, sharpness is an instance of custom class
enable: true
# A subset of these transforms will be applied for each batch. This is the maximum size of that subset.
# A subset of these transforms will be applied for each batch. This is the maximum size of that subset.
max_num_transforms: 3
# Torchvision suggests applying the transforms in the following order: brightness, contrast, saturation, hue
# sharpness can be applied at any time before or after (we choose after).

View File

@@ -4,14 +4,13 @@ import hydra
from torchvision.transforms import ToPILImage
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.transforms import make_transforms
from lerobot.common.datasets.transforms import make_image_transforms
to_pil = ToPILImage()
def main(cfg, output_dir=Path("outputs/image_transforms")):
dataset = LeRobotDataset(cfg.dataset_repo_id, transform=None)
dataset = LeRobotDataset(cfg.dataset_repo_id, image_transforms=None)
output_dir = Path(output_dir) / Path(cfg.dataset_repo_id.split("/")[-1])
output_dir.mkdir(parents=True, exist_ok=True)
@@ -27,11 +26,11 @@ def main(cfg, output_dir=Path("outputs/image_transforms")):
for transform_name in transforms:
for t in transforms:
if t == transform_name:
cfg.image_transform[t].weight = 1
cfg.image_transforms[t].weight = 1
else:
cfg.image_transform[t].weight = 0
cfg.image_transforms[t].weight = 0
transform = make_transforms(cfg.image_transform)
transform = make_image_transforms(cfg.image_transforms)
img = transform(frame)
to_pil(img).save(output_dir / f"{transform_name}.png", quality=100)

View File

@@ -15,7 +15,7 @@ to_pil = v2.ToPILImage()
def main(repo_id):
dataset = LeRobotDataset(repo_id, transform=None)
dataset = LeRobotDataset(repo_id, image_transforms=None)
output_dir = Path(ARTIFACT_DIR)
output_dir.mkdir(parents=True, exist_ok=True)

View File

@@ -8,7 +8,7 @@ from torchvision.transforms.v2 import functional as F # noqa: N812
from PIL import Image
from safetensors.torch import load_file
from lerobot.common.datasets.transforms import RandomSubsetApply, RangeRandomSharpness, make_transforms
from lerobot.common.datasets.transforms import RandomSubsetApply, RangeRandomSharpness, make_image_transforms
from lerobot.common.datasets.utils import flatten_dict
from lerobot.common.utils.utils import init_hydra_config, seeded_context
from tests.utils import DEFAULT_CONFIG_PATH
@@ -147,7 +147,7 @@ class TestMakeTransforms:
config = self.config
config[transform_key]["weight"] = 1
cfg = OmegaConf.create(config)
transform = make_transforms(cfg, to_dtype=torch.uint8)
transform = make_image_transforms(cfg, to_dtype=torch.uint8)
# expected_t = self.transforms[transform_key]
with seeded_context(seed):