rename to image_transforms
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
This script demonstrates how to implement torchvision image augmentation on an instance of a LeRobotDataset and how to show some transformed images.
|
||||
This script demonstrates how to implement torchvision image augmentation on an instance of a LeRobotDataset and how to show some transformed images.
|
||||
The transformations are passed to the dataset as an argument upon creation, and transforms are applied to the observation images before they are returned.
|
||||
"""
|
||||
|
||||
@@ -18,7 +18,7 @@ output_dir.mkdir(parents=True, exist_ok=True)
|
||||
repo_id = "lerobot/aloha_static_tape"
|
||||
|
||||
# Create a LeRobotDataset with no transformations
|
||||
dataset = LeRobotDataset(repo_id, transform=None)
|
||||
dataset = LeRobotDataset(repo_id, image_transforms=None)
|
||||
|
||||
# Get the index of the first observation in the first episode
|
||||
first_idx = dataset.episode_data_index["from"][0].item()
|
||||
@@ -41,7 +41,7 @@ transforms = v2.Compose(
|
||||
)
|
||||
|
||||
# Create another LeRobotDataset with the defined transformations
|
||||
transformed_dataset = LeRobotDataset(repo_id, transform=transforms)
|
||||
transformed_dataset = LeRobotDataset(repo_id, image_transforms=transforms)
|
||||
|
||||
# Get a frame from the transformed dataset
|
||||
transformed_frame = transformed_dataset[first_idx][transformed_dataset.camera_keys[0]]
|
||||
|
||||
@@ -19,7 +19,7 @@ import torch
|
||||
from omegaconf import ListConfig, OmegaConf
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
|
||||
from lerobot.common.datasets.transforms import make_transforms
|
||||
from lerobot.common.datasets.transforms import make_image_transforms
|
||||
|
||||
|
||||
def resolve_delta_timestamps(cfg):
|
||||
@@ -72,21 +72,21 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
|
||||
|
||||
resolve_delta_timestamps(cfg)
|
||||
|
||||
transform = make_transforms(cfg.image_transform) if cfg.image_transform.enable else None
|
||||
image_transforms = make_image_transforms(cfg.image_transforms) if cfg.image_transforms.enable else None
|
||||
|
||||
if isinstance(cfg.dataset_repo_id, str):
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset_repo_id,
|
||||
split=split,
|
||||
delta_timestamps=cfg.training.get("delta_timestamps"),
|
||||
transform=transform,
|
||||
image_transforms=image_transforms,
|
||||
)
|
||||
else:
|
||||
dataset = MultiLeRobotDataset(
|
||||
cfg.dataset_repo_id,
|
||||
split=split,
|
||||
delta_timestamps=cfg.training.get("delta_timestamps"),
|
||||
transform=transform,
|
||||
image_transforms=image_transforms,
|
||||
)
|
||||
|
||||
if cfg.get("override_dataset_stats"):
|
||||
|
||||
@@ -46,7 +46,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
version: str | None = CODEBASE_VERSION,
|
||||
root: Path | None = DATA_DIR,
|
||||
split: str = "train",
|
||||
transform: Callable | None = None,
|
||||
image_transforms: Callable | None = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -54,7 +54,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.version = version
|
||||
self.root = root
|
||||
self.split = split
|
||||
self.transform = transform
|
||||
self.transform = image_transforms
|
||||
self.delta_timestamps = delta_timestamps
|
||||
# load data from hub or locally when root is provided
|
||||
# TODO(rcadene, aliberts): implement faster transfer
|
||||
@@ -226,7 +226,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
version: str | None = CODEBASE_VERSION,
|
||||
root: Path | None = DATA_DIR,
|
||||
split: str = "train",
|
||||
transform: Callable | None = None,
|
||||
image_transforms: Callable | None = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -240,7 +240,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
root=root,
|
||||
split=split,
|
||||
delta_timestamps=delta_timestamps,
|
||||
transform=transform,
|
||||
image_transforms=image_transforms,
|
||||
)
|
||||
for repo_id in repo_ids
|
||||
]
|
||||
@@ -275,7 +275,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
self.version = version
|
||||
self.root = root
|
||||
self.split = split
|
||||
self.transform = transform
|
||||
self.transform = image_transforms
|
||||
self.delta_timestamps = delta_timestamps
|
||||
self.stats = aggregate_stats(self._datasets)
|
||||
|
||||
|
||||
@@ -98,7 +98,7 @@ class RangeRandomSharpness(Transform):
|
||||
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
|
||||
|
||||
|
||||
def make_transforms(cfg, to_dtype: torch.dtype = torch.float32):
|
||||
def make_image_transforms(cfg, to_dtype: torch.dtype = torch.float32):
|
||||
transforms_list = [
|
||||
v2.ColorJitter(brightness=(cfg.brightness.min, cfg.brightness.max)),
|
||||
v2.ColorJitter(contrast=(cfg.contrast.min, cfg.contrast.max)),
|
||||
|
||||
@@ -58,10 +58,10 @@ wandb:
|
||||
project: lerobot
|
||||
notes: ""
|
||||
|
||||
image_transform:
|
||||
image_transforms:
|
||||
# brigthness, contrast, saturation and hue are instances of torchvision Colorjitter, sharpness is an instance of custom class
|
||||
enable: true
|
||||
# A subset of these transforms will be applied for each batch. This is the maximum size of that subset.
|
||||
# A subset of these transforms will be applied for each batch. This is the maximum size of that subset.
|
||||
max_num_transforms: 3
|
||||
# Torchvision suggest applying the transforms in the following order : brightness, contrast, saturation, hue
|
||||
# sharpness can be applied at any time before or after (we choose after).
|
||||
|
||||
@@ -4,14 +4,13 @@ import hydra
|
||||
from torchvision.transforms import ToPILImage
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.transforms import make_transforms
|
||||
from lerobot.common.datasets.transforms import make_image_transforms
|
||||
|
||||
to_pil = ToPILImage()
|
||||
|
||||
|
||||
def main(cfg, output_dir=Path("outputs/image_transforms")):
|
||||
|
||||
dataset = LeRobotDataset(cfg.dataset_repo_id, transform=None)
|
||||
dataset = LeRobotDataset(cfg.dataset_repo_id, image_transforms=None)
|
||||
|
||||
output_dir = Path(output_dir) / Path(cfg.dataset_repo_id.split("/")[-1])
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
@@ -27,11 +26,11 @@ def main(cfg, output_dir=Path("outputs/image_transforms")):
|
||||
for transform_name in transforms:
|
||||
for t in transforms:
|
||||
if t == transform_name:
|
||||
cfg.image_transform[t].weight = 1
|
||||
cfg.image_transforms[t].weight = 1
|
||||
else:
|
||||
cfg.image_transform[t].weight = 0
|
||||
cfg.image_transforms[t].weight = 0
|
||||
|
||||
transform = make_transforms(cfg.image_transform)
|
||||
transform = make_image_transforms(cfg.image_transforms)
|
||||
img = transform(frame)
|
||||
to_pil(img).save(output_dir / f"{transform_name}.png", quality=100)
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ to_pil = v2.ToPILImage()
|
||||
|
||||
|
||||
def main(repo_id):
|
||||
dataset = LeRobotDataset(repo_id, transform=None)
|
||||
dataset = LeRobotDataset(repo_id, image_transforms=None)
|
||||
output_dir = Path(ARTIFACT_DIR)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from torchvision.transforms.v2 import functional as F # noqa: N812
|
||||
from PIL import Image
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from lerobot.common.datasets.transforms import RandomSubsetApply, RangeRandomSharpness, make_transforms
|
||||
from lerobot.common.datasets.transforms import RandomSubsetApply, RangeRandomSharpness, make_image_transforms
|
||||
from lerobot.common.datasets.utils import flatten_dict
|
||||
from lerobot.common.utils.utils import init_hydra_config, seeded_context
|
||||
from tests.utils import DEFAULT_CONFIG_PATH
|
||||
@@ -147,7 +147,7 @@ class TestMakeTransforms:
|
||||
config = self.config
|
||||
config[transform_key]["weight"] = 1
|
||||
cfg = OmegaConf.create(config)
|
||||
transform = make_transforms(cfg, to_dtype=torch.uint8)
|
||||
transform = make_image_transforms(cfg, to_dtype=torch.uint8)
|
||||
|
||||
# expected_t = self.transforms[transform_key]
|
||||
with seeded_context(seed):
|
||||
|
||||
Reference in New Issue
Block a user