diff --git a/examples/6_add_image_transforms.py b/examples/6_add_image_transforms.py
index be0eb957..6021c160 100644
--- a/examples/6_add_image_transforms.py
+++ b/examples/6_add_image_transforms.py
@@ -1,5 +1,5 @@
 """
-This script demonstrates how to implement torchvision image augmentation on an instance of a LeRobotDataset and how to show some transformed images.
+This script demonstrates how to implement torchvision image augmentation on an instance of a LeRobotDataset and how to show some transformed images. The transformations are passed to the dataset as an argument upon creation, and transforms are applied to the observation images before they are returned.
 """
@@ -18,7 +18,7 @@
 output_dir.mkdir(parents=True, exist_ok=True)

 repo_id = "lerobot/aloha_static_tape"

 # Create a LeRobotDataset with no transformations
-dataset = LeRobotDataset(repo_id, transform=None)
+dataset = LeRobotDataset(repo_id, image_transforms=None)

 # Get the index of the first observation in the first episode
 first_idx = dataset.episode_data_index["from"][0].item()
@@ -41,7 +41,7 @@
 )

 # Create another LeRobotDataset with the defined transformations
-transformed_dataset = LeRobotDataset(repo_id, transform=transforms)
+transformed_dataset = LeRobotDataset(repo_id, image_transforms=transforms)

 # Get a frame from the transformed dataset
 transformed_frame = transformed_dataset[first_idx][transformed_dataset.camera_keys[0]]
diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py
index 370d1640..bae0677e 100644
--- a/lerobot/common/datasets/factory.py
+++ b/lerobot/common/datasets/factory.py
@@ -19,7 +19,7 @@
 import torch
 from omegaconf import ListConfig, OmegaConf

 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
-from lerobot.common.datasets.transforms import make_transforms
+from lerobot.common.datasets.transforms import make_image_transforms


 def resolve_delta_timestamps(cfg):
@@ -72,21 +72,21 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
     resolve_delta_timestamps(cfg)

-    transform = make_transforms(cfg.image_transform) if cfg.image_transform.enable else None
+    image_transforms = make_image_transforms(cfg.image_transforms) if cfg.image_transforms.enable else None

     if isinstance(cfg.dataset_repo_id, str):
         dataset = LeRobotDataset(
             cfg.dataset_repo_id,
             split=split,
             delta_timestamps=cfg.training.get("delta_timestamps"),
-            transform=transform,
+            image_transforms=image_transforms,
         )
     else:
         dataset = MultiLeRobotDataset(
             cfg.dataset_repo_id,
             split=split,
             delta_timestamps=cfg.training.get("delta_timestamps"),
-            transform=transform,
+            image_transforms=image_transforms,
         )

     if cfg.get("override_dataset_stats"):
diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py
index 7da5f430..f41b8f7a 100644
--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -46,7 +46,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         version: str | None = CODEBASE_VERSION,
         root: Path | None = DATA_DIR,
         split: str = "train",
-        transform: Callable | None = None,
+        image_transforms: Callable | None = None,
         delta_timestamps: dict[list[float]] | None = None,
     ):
         super().__init__()
@@ -54,7 +54,7 @@
         self.version = version
         self.root = root
         self.split = split
-        self.transform = transform
+        self.transform = image_transforms
         self.delta_timestamps = delta_timestamps
         # load data from hub or locally when root is provided
         # TODO(rcadene, aliberts): implement faster transfer
@@ -226,7 +226,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
         version: str | None = CODEBASE_VERSION,
         root: Path | None = DATA_DIR,
         split: str = "train",
-        transform: Callable | None = None,
+        image_transforms: Callable | None = None,
         delta_timestamps: dict[list[float]] | None = None,
     ):
         super().__init__()
@@ -240,7 +240,7 @@
                 root=root,
                 split=split,
                 delta_timestamps=delta_timestamps,
-                transform=transform,
+                image_transforms=image_transforms,
             )
             for repo_id in repo_ids
         ]
@@ -275,7 +275,7 @@
         self.version = version
         self.root = root
         self.split = split
-        self.transform = transform
+        self.transform = image_transforms
         self.delta_timestamps = delta_timestamps
         self.stats = aggregate_stats(self._datasets)
diff --git a/lerobot/common/datasets/transforms.py b/lerobot/common/datasets/transforms.py
index eeab633d..5a1064bb 100644
--- a/lerobot/common/datasets/transforms.py
+++ b/lerobot/common/datasets/transforms.py
@@ -98,7 +98,7 @@ class RangeRandomSharpness(Transform):
         return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)


-def make_transforms(cfg, to_dtype: torch.dtype = torch.float32):
+def make_image_transforms(cfg, to_dtype: torch.dtype = torch.float32):
     transforms_list = [
         v2.ColorJitter(brightness=(cfg.brightness.min, cfg.brightness.max)),
         v2.ColorJitter(contrast=(cfg.contrast.min, cfg.contrast.max)),
diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml
index e1be2012..e30fb638 100644
--- a/lerobot/configs/default.yaml
+++ b/lerobot/configs/default.yaml
@@ -58,10 +58,10 @@ wandb:
   project: lerobot
   notes: ""

-image_transform:
+image_transforms:
   # brigthness, contrast, saturation and hue are instances of torchvision Colorjitter, sharpness is an instance of custom class
   enable: true
-  # A subset of these transforms will be applied for each batch. This is the maximum size of that subset.
+  # A subset of these transforms will be applied for each batch. This is the maximum size of that subset.
   max_num_transforms: 3
   # Torchvision suggest applying the transforms in the following order : brightness, contrast, saturation, hue
   # sharpness can be applied at any time before or after (we choose after).
diff --git a/lerobot/scripts/visualize_image_transforms.py b/lerobot/scripts/visualize_image_transforms.py
index 5b00bcab..332fa82a 100644
--- a/lerobot/scripts/visualize_image_transforms.py
+++ b/lerobot/scripts/visualize_image_transforms.py
@@ -4,14 +4,13 @@
 import hydra
 from torchvision.transforms import ToPILImage

 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.common.datasets.transforms import make_transforms
+from lerobot.common.datasets.transforms import make_image_transforms

 to_pil = ToPILImage()


 def main(cfg, output_dir=Path("outputs/image_transforms")):
-
-    dataset = LeRobotDataset(cfg.dataset_repo_id, transform=None)
+    dataset = LeRobotDataset(cfg.dataset_repo_id, image_transforms=None)
     output_dir = Path(output_dir) / Path(cfg.dataset_repo_id.split("/")[-1])
     output_dir.mkdir(parents=True, exist_ok=True)
@@ -27,11 +26,11 @@ def main(cfg, output_dir=Path("outputs/image_transforms")):
     for transform_name in transforms:
         for t in transforms:
             if t == transform_name:
-                cfg.image_transform[t].weight = 1
+                cfg.image_transforms[t].weight = 1
             else:
-                cfg.image_transform[t].weight = 0
+                cfg.image_transforms[t].weight = 0

-        transform = make_transforms(cfg.image_transform)
+        transform = make_image_transforms(cfg.image_transforms)
         img = transform(frame)
         to_pil(img).save(output_dir / f"{transform_name}.png", quality=100)
diff --git a/tests/scripts/save_image_transforms.py b/tests/scripts/save_image_transforms.py
index 777dcd96..1162a307 100644
--- a/tests/scripts/save_image_transforms.py
+++ b/tests/scripts/save_image_transforms.py
@@ -15,7 +15,7 @@ to_pil = v2.ToPILImage()


 def main(repo_id):
-    dataset = LeRobotDataset(repo_id, transform=None)
+    dataset = LeRobotDataset(repo_id, image_transforms=None)

     output_dir = Path(ARTIFACT_DIR)
     output_dir.mkdir(parents=True, exist_ok=True)
diff --git a/tests/test_transforms.py b/tests/test_transforms.py
index 8429997a..16f04f9a 100644
--- a/tests/test_transforms.py
+++ b/tests/test_transforms.py
@@ -8,7 +8,7 @@ from torchvision.transforms.v2 import functional as F  # noqa: N812
 from PIL import Image
 from safetensors.torch import load_file

-from lerobot.common.datasets.transforms import RandomSubsetApply, RangeRandomSharpness, make_transforms
+from lerobot.common.datasets.transforms import RandomSubsetApply, RangeRandomSharpness, make_image_transforms
 from lerobot.common.datasets.utils import flatten_dict
 from lerobot.common.utils.utils import init_hydra_config, seeded_context
 from tests.utils import DEFAULT_CONFIG_PATH
@@ -147,7 +147,7 @@ class TestMakeTransforms:
         config = self.config
         config[transform_key]["weight"] = 1
         cfg = OmegaConf.create(config)
-        transform = make_transforms(cfg, to_dtype=torch.uint8)
+        transform = make_image_transforms(cfg, to_dtype=torch.uint8)
         # expected_t = self.transforms[transform_key]

         with seeded_context(seed):
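
For reference, a minimal usage sketch of the renamed argument, mirroring examples/6_add_image_transforms.py: transforms are passed at dataset creation via `image_transforms` (previously `transform`) and applied to observation images before a frame is returned. The repo_id and the ColorJitter ranges below are illustrative values, not part of this patch.

from torchvision.transforms import v2

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Illustrative dataset and jitter ranges; any LeRobotDataset repo_id works the same way.
repo_id = "lerobot/aloha_static_tape"
transforms = v2.Compose(
    [
        v2.ColorJitter(brightness=(0.5, 1.5)),
        v2.ColorJitter(contrast=(0.5, 1.5)),
    ]
)

# The transforms are passed once, at creation time, via the renamed `image_transforms` argument
# and are applied to the observation images each time a frame is fetched.
dataset = LeRobotDataset(repo_id, image_transforms=transforms)
frame = dataset[0][dataset.camera_keys[0]]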