rename to image_transforms

This commit is contained in:
Simon Alibert
2024-06-06 16:50:22 +00:00
parent a86f387554
commit c45dd8f848
8 changed files with 23 additions and 24 deletions

View File

@@ -18,7 +18,7 @@ output_dir.mkdir(parents=True, exist_ok=True)
repo_id = "lerobot/aloha_static_tape" repo_id = "lerobot/aloha_static_tape"
# Create a LeRobotDataset with no transformations # Create a LeRobotDataset with no transformations
dataset = LeRobotDataset(repo_id, transform=None) dataset = LeRobotDataset(repo_id, image_transforms=None)
# Get the index of the first observation in the first episode # Get the index of the first observation in the first episode
first_idx = dataset.episode_data_index["from"][0].item() first_idx = dataset.episode_data_index["from"][0].item()
@@ -41,7 +41,7 @@ transforms = v2.Compose(
) )
# Create another LeRobotDataset with the defined transformations # Create another LeRobotDataset with the defined transformations
transformed_dataset = LeRobotDataset(repo_id, transform=transforms) transformed_dataset = LeRobotDataset(repo_id, image_transforms=transforms)
# Get a frame from the transformed dataset # Get a frame from the transformed dataset
transformed_frame = transformed_dataset[first_idx][transformed_dataset.camera_keys[0]] transformed_frame = transformed_dataset[first_idx][transformed_dataset.camera_keys[0]]

View File

@@ -19,7 +19,7 @@ import torch
from omegaconf import ListConfig, OmegaConf from omegaconf import ListConfig, OmegaConf
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
from lerobot.common.datasets.transforms import make_transforms from lerobot.common.datasets.transforms import make_image_transforms
def resolve_delta_timestamps(cfg): def resolve_delta_timestamps(cfg):
@@ -72,21 +72,21 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
resolve_delta_timestamps(cfg) resolve_delta_timestamps(cfg)
transform = make_transforms(cfg.image_transform) if cfg.image_transform.enable else None image_transforms = make_image_transforms(cfg.image_transforms) if cfg.image_transforms.enable else None
if isinstance(cfg.dataset_repo_id, str): if isinstance(cfg.dataset_repo_id, str):
dataset = LeRobotDataset( dataset = LeRobotDataset(
cfg.dataset_repo_id, cfg.dataset_repo_id,
split=split, split=split,
delta_timestamps=cfg.training.get("delta_timestamps"), delta_timestamps=cfg.training.get("delta_timestamps"),
transform=transform, image_transforms=image_transforms,
) )
else: else:
dataset = MultiLeRobotDataset( dataset = MultiLeRobotDataset(
cfg.dataset_repo_id, cfg.dataset_repo_id,
split=split, split=split,
delta_timestamps=cfg.training.get("delta_timestamps"), delta_timestamps=cfg.training.get("delta_timestamps"),
transform=transform, image_transforms=image_transforms,
) )
if cfg.get("override_dataset_stats"): if cfg.get("override_dataset_stats"):

View File

@@ -46,7 +46,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
version: str | None = CODEBASE_VERSION, version: str | None = CODEBASE_VERSION,
root: Path | None = DATA_DIR, root: Path | None = DATA_DIR,
split: str = "train", split: str = "train",
transform: Callable | None = None, image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None, delta_timestamps: dict[list[float]] | None = None,
): ):
super().__init__() super().__init__()
@@ -54,7 +54,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.version = version self.version = version
self.root = root self.root = root
self.split = split self.split = split
self.transform = transform self.transform = image_transforms
self.delta_timestamps = delta_timestamps self.delta_timestamps = delta_timestamps
# load data from hub or locally when root is provided # load data from hub or locally when root is provided
# TODO(rcadene, aliberts): implement faster transfer # TODO(rcadene, aliberts): implement faster transfer
@@ -226,7 +226,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
version: str | None = CODEBASE_VERSION, version: str | None = CODEBASE_VERSION,
root: Path | None = DATA_DIR, root: Path | None = DATA_DIR,
split: str = "train", split: str = "train",
transform: Callable | None = None, image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None, delta_timestamps: dict[list[float]] | None = None,
): ):
super().__init__() super().__init__()
@@ -240,7 +240,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
root=root, root=root,
split=split, split=split,
delta_timestamps=delta_timestamps, delta_timestamps=delta_timestamps,
transform=transform, image_transforms=image_transforms,
) )
for repo_id in repo_ids for repo_id in repo_ids
] ]
@@ -275,7 +275,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
self.version = version self.version = version
self.root = root self.root = root
self.split = split self.split = split
self.transform = transform self.transform = image_transforms
self.delta_timestamps = delta_timestamps self.delta_timestamps = delta_timestamps
self.stats = aggregate_stats(self._datasets) self.stats = aggregate_stats(self._datasets)

View File

@@ -98,7 +98,7 @@ class RangeRandomSharpness(Transform):
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor) return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
def make_transforms(cfg, to_dtype: torch.dtype = torch.float32): def make_image_transforms(cfg, to_dtype: torch.dtype = torch.float32):
transforms_list = [ transforms_list = [
v2.ColorJitter(brightness=(cfg.brightness.min, cfg.brightness.max)), v2.ColorJitter(brightness=(cfg.brightness.min, cfg.brightness.max)),
v2.ColorJitter(contrast=(cfg.contrast.min, cfg.contrast.max)), v2.ColorJitter(contrast=(cfg.contrast.min, cfg.contrast.max)),

View File

@@ -58,7 +58,7 @@ wandb:
project: lerobot project: lerobot
notes: "" notes: ""
image_transform: image_transforms:
# brigthness, contrast, saturation and hue are instances of torchvision Colorjitter, sharpness is an instance of custom class # brigthness, contrast, saturation and hue are instances of torchvision Colorjitter, sharpness is an instance of custom class
enable: true enable: true
# A subset of these transforms will be applied for each batch. This is the maximum size of that subset. # A subset of these transforms will be applied for each batch. This is the maximum size of that subset.

View File

@@ -4,14 +4,13 @@ import hydra
from torchvision.transforms import ToPILImage from torchvision.transforms import ToPILImage
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.transforms import make_transforms from lerobot.common.datasets.transforms import make_image_transforms
to_pil = ToPILImage() to_pil = ToPILImage()
def main(cfg, output_dir=Path("outputs/image_transforms")): def main(cfg, output_dir=Path("outputs/image_transforms")):
dataset = LeRobotDataset(cfg.dataset_repo_id, image_transforms=None)
dataset = LeRobotDataset(cfg.dataset_repo_id, transform=None)
output_dir = Path(output_dir) / Path(cfg.dataset_repo_id.split("/")[-1]) output_dir = Path(output_dir) / Path(cfg.dataset_repo_id.split("/")[-1])
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)
@@ -27,11 +26,11 @@ def main(cfg, output_dir=Path("outputs/image_transforms")):
for transform_name in transforms: for transform_name in transforms:
for t in transforms: for t in transforms:
if t == transform_name: if t == transform_name:
cfg.image_transform[t].weight = 1 cfg.image_transforms[t].weight = 1
else: else:
cfg.image_transform[t].weight = 0 cfg.image_transforms[t].weight = 0
transform = make_transforms(cfg.image_transform) transform = make_image_transforms(cfg.image_transforms)
img = transform(frame) img = transform(frame)
to_pil(img).save(output_dir / f"{transform_name}.png", quality=100) to_pil(img).save(output_dir / f"{transform_name}.png", quality=100)

View File

@@ -15,7 +15,7 @@ to_pil = v2.ToPILImage()
def main(repo_id): def main(repo_id):
dataset = LeRobotDataset(repo_id, transform=None) dataset = LeRobotDataset(repo_id, image_transforms=None)
output_dir = Path(ARTIFACT_DIR) output_dir = Path(ARTIFACT_DIR)
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)

View File

@@ -8,7 +8,7 @@ from torchvision.transforms.v2 import functional as F # noqa: N812
from PIL import Image from PIL import Image
from safetensors.torch import load_file from safetensors.torch import load_file
from lerobot.common.datasets.transforms import RandomSubsetApply, RangeRandomSharpness, make_transforms from lerobot.common.datasets.transforms import RandomSubsetApply, RangeRandomSharpness, make_image_transforms
from lerobot.common.datasets.utils import flatten_dict from lerobot.common.datasets.utils import flatten_dict
from lerobot.common.utils.utils import init_hydra_config, seeded_context from lerobot.common.utils.utils import init_hydra_config, seeded_context
from tests.utils import DEFAULT_CONFIG_PATH from tests.utils import DEFAULT_CONFIG_PATH
@@ -147,7 +147,7 @@ class TestMakeTransforms:
config = self.config config = self.config
config[transform_key]["weight"] = 1 config[transform_key]["weight"] = 1
cfg = OmegaConf.create(config) cfg = OmegaConf.create(config)
transform = make_transforms(cfg, to_dtype=torch.uint8) transform = make_image_transforms(cfg, to_dtype=torch.uint8)
# expected_t = self.transforms[transform_key] # expected_t = self.transforms[transform_key]
with seeded_context(seed): with seeded_context(seed):