rename to image_transforms

This commit is contained in:
Simon Alibert
2024-06-06 16:50:22 +00:00
parent a86f387554
commit c45dd8f848
8 changed files with 23 additions and 24 deletions

View File

@@ -1,5 +1,5 @@
""" """
This script demonstrates how to implement torchvision image augmentation on an instance of a LeRobotDataset and how to show some transformed images. This script demonstrates how to implement torchvision image augmentation on an instance of a LeRobotDataset and how to show some transformed images.
The transformations are passed to the dataset as an argument upon creation, and transforms are applied to the observation images before they are returned. The transformations are passed to the dataset as an argument upon creation, and transforms are applied to the observation images before they are returned.
""" """
@@ -18,7 +18,7 @@ output_dir.mkdir(parents=True, exist_ok=True)
repo_id = "lerobot/aloha_static_tape" repo_id = "lerobot/aloha_static_tape"
# Create a LeRobotDataset with no transformations # Create a LeRobotDataset with no transformations
dataset = LeRobotDataset(repo_id, transform=None) dataset = LeRobotDataset(repo_id, image_transforms=None)
# Get the index of the first observation in the first episode # Get the index of the first observation in the first episode
first_idx = dataset.episode_data_index["from"][0].item() first_idx = dataset.episode_data_index["from"][0].item()
@@ -41,7 +41,7 @@ transforms = v2.Compose(
) )
# Create another LeRobotDataset with the defined transformations # Create another LeRobotDataset with the defined transformations
transformed_dataset = LeRobotDataset(repo_id, transform=transforms) transformed_dataset = LeRobotDataset(repo_id, image_transforms=transforms)
# Get a frame from the transformed dataset # Get a frame from the transformed dataset
transformed_frame = transformed_dataset[first_idx][transformed_dataset.camera_keys[0]] transformed_frame = transformed_dataset[first_idx][transformed_dataset.camera_keys[0]]

View File

@@ -19,7 +19,7 @@ import torch
from omegaconf import ListConfig, OmegaConf from omegaconf import ListConfig, OmegaConf
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
from lerobot.common.datasets.transforms import make_transforms from lerobot.common.datasets.transforms import make_image_transforms
def resolve_delta_timestamps(cfg): def resolve_delta_timestamps(cfg):
@@ -72,21 +72,21 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
resolve_delta_timestamps(cfg) resolve_delta_timestamps(cfg)
transform = make_transforms(cfg.image_transform) if cfg.image_transform.enable else None image_transforms = make_image_transforms(cfg.image_transforms) if cfg.image_transforms.enable else None
if isinstance(cfg.dataset_repo_id, str): if isinstance(cfg.dataset_repo_id, str):
dataset = LeRobotDataset( dataset = LeRobotDataset(
cfg.dataset_repo_id, cfg.dataset_repo_id,
split=split, split=split,
delta_timestamps=cfg.training.get("delta_timestamps"), delta_timestamps=cfg.training.get("delta_timestamps"),
transform=transform, image_transforms=image_transforms,
) )
else: else:
dataset = MultiLeRobotDataset( dataset = MultiLeRobotDataset(
cfg.dataset_repo_id, cfg.dataset_repo_id,
split=split, split=split,
delta_timestamps=cfg.training.get("delta_timestamps"), delta_timestamps=cfg.training.get("delta_timestamps"),
transform=transform, image_transforms=image_transforms,
) )
if cfg.get("override_dataset_stats"): if cfg.get("override_dataset_stats"):

View File

@@ -46,7 +46,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
version: str | None = CODEBASE_VERSION, version: str | None = CODEBASE_VERSION,
root: Path | None = DATA_DIR, root: Path | None = DATA_DIR,
split: str = "train", split: str = "train",
transform: Callable | None = None, image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None, delta_timestamps: dict[list[float]] | None = None,
): ):
super().__init__() super().__init__()
@@ -54,7 +54,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.version = version self.version = version
self.root = root self.root = root
self.split = split self.split = split
self.transform = transform self.transform = image_transforms
self.delta_timestamps = delta_timestamps self.delta_timestamps = delta_timestamps
# load data from hub or locally when root is provided # load data from hub or locally when root is provided
# TODO(rcadene, aliberts): implement faster transfer # TODO(rcadene, aliberts): implement faster transfer
@@ -226,7 +226,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
version: str | None = CODEBASE_VERSION, version: str | None = CODEBASE_VERSION,
root: Path | None = DATA_DIR, root: Path | None = DATA_DIR,
split: str = "train", split: str = "train",
transform: Callable | None = None, image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None, delta_timestamps: dict[list[float]] | None = None,
): ):
super().__init__() super().__init__()
@@ -240,7 +240,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
root=root, root=root,
split=split, split=split,
delta_timestamps=delta_timestamps, delta_timestamps=delta_timestamps,
transform=transform, image_transforms=image_transforms,
) )
for repo_id in repo_ids for repo_id in repo_ids
] ]
@@ -275,7 +275,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
self.version = version self.version = version
self.root = root self.root = root
self.split = split self.split = split
self.transform = transform self.transform = image_transforms
self.delta_timestamps = delta_timestamps self.delta_timestamps = delta_timestamps
self.stats = aggregate_stats(self._datasets) self.stats = aggregate_stats(self._datasets)

View File

@@ -98,7 +98,7 @@ class RangeRandomSharpness(Transform):
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor) return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
def make_transforms(cfg, to_dtype: torch.dtype = torch.float32): def make_image_transforms(cfg, to_dtype: torch.dtype = torch.float32):
transforms_list = [ transforms_list = [
v2.ColorJitter(brightness=(cfg.brightness.min, cfg.brightness.max)), v2.ColorJitter(brightness=(cfg.brightness.min, cfg.brightness.max)),
v2.ColorJitter(contrast=(cfg.contrast.min, cfg.contrast.max)), v2.ColorJitter(contrast=(cfg.contrast.min, cfg.contrast.max)),

View File

@@ -58,10 +58,10 @@ wandb:
project: lerobot project: lerobot
notes: "" notes: ""
image_transform: image_transforms:
# brightness, contrast, saturation and hue are instances of torchvision ColorJitter, sharpness is an instance of a custom class # brightness, contrast, saturation and hue are instances of torchvision ColorJitter, sharpness is an instance of a custom class
enable: true enable: true
# A subset of these transforms will be applied for each batch. This is the maximum size of that subset. # A subset of these transforms will be applied for each batch. This is the maximum size of that subset.
max_num_transforms: 3 max_num_transforms: 3
# Torchvision suggests applying the transforms in the following order: brightness, contrast, saturation, hue # Torchvision suggests applying the transforms in the following order: brightness, contrast, saturation, hue
# sharpness can be applied at any time before or after (we choose after). # sharpness can be applied at any time before or after (we choose after).

View File

@@ -4,14 +4,13 @@ import hydra
from torchvision.transforms import ToPILImage from torchvision.transforms import ToPILImage
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.transforms import make_transforms from lerobot.common.datasets.transforms import make_image_transforms
to_pil = ToPILImage() to_pil = ToPILImage()
def main(cfg, output_dir=Path("outputs/image_transforms")): def main(cfg, output_dir=Path("outputs/image_transforms")):
dataset = LeRobotDataset(cfg.dataset_repo_id, image_transforms=None)
dataset = LeRobotDataset(cfg.dataset_repo_id, transform=None)
output_dir = Path(output_dir) / Path(cfg.dataset_repo_id.split("/")[-1]) output_dir = Path(output_dir) / Path(cfg.dataset_repo_id.split("/")[-1])
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)
@@ -27,11 +26,11 @@ def main(cfg, output_dir=Path("outputs/image_transforms")):
for transform_name in transforms: for transform_name in transforms:
for t in transforms: for t in transforms:
if t == transform_name: if t == transform_name:
cfg.image_transform[t].weight = 1 cfg.image_transforms[t].weight = 1
else: else:
cfg.image_transform[t].weight = 0 cfg.image_transforms[t].weight = 0
transform = make_transforms(cfg.image_transform) transform = make_image_transforms(cfg.image_transforms)
img = transform(frame) img = transform(frame)
to_pil(img).save(output_dir / f"{transform_name}.png", quality=100) to_pil(img).save(output_dir / f"{transform_name}.png", quality=100)

View File

@@ -15,7 +15,7 @@ to_pil = v2.ToPILImage()
def main(repo_id): def main(repo_id):
dataset = LeRobotDataset(repo_id, transform=None) dataset = LeRobotDataset(repo_id, image_transforms=None)
output_dir = Path(ARTIFACT_DIR) output_dir = Path(ARTIFACT_DIR)
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)

View File

@@ -8,7 +8,7 @@ from torchvision.transforms.v2 import functional as F # noqa: N812
from PIL import Image from PIL import Image
from safetensors.torch import load_file from safetensors.torch import load_file
from lerobot.common.datasets.transforms import RandomSubsetApply, RangeRandomSharpness, make_transforms from lerobot.common.datasets.transforms import RandomSubsetApply, RangeRandomSharpness, make_image_transforms
from lerobot.common.datasets.utils import flatten_dict from lerobot.common.datasets.utils import flatten_dict
from lerobot.common.utils.utils import init_hydra_config, seeded_context from lerobot.common.utils.utils import init_hydra_config, seeded_context
from tests.utils import DEFAULT_CONFIG_PATH from tests.utils import DEFAULT_CONFIG_PATH
@@ -147,7 +147,7 @@ class TestMakeTransforms:
config = self.config config = self.config
config[transform_key]["weight"] = 1 config[transform_key]["weight"] = 1
cfg = OmegaConf.create(config) cfg = OmegaConf.create(config)
transform = make_transforms(cfg, to_dtype=torch.uint8) transform = make_image_transforms(cfg, to_dtype=torch.uint8)
# expected_t = self.transforms[transform_key] # expected_t = self.transforms[transform_key]
with seeded_context(seed): with seeded_context(seed):