This commit is contained in:
Remi Cadene
2024-06-08 12:10:04 +02:00
parent e52942a200
commit ca81b1d6f4
4 changed files with 305 additions and 159 deletions

View File

@@ -19,7 +19,7 @@ import torch
from omegaconf import ListConfig, OmegaConf
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
from lerobot.common.datasets.transforms import make_image_transforms
from lerobot.common.datasets.transforms import get_image_transforms
def resolve_delta_timestamps(cfg):
@@ -72,7 +72,22 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
resolve_delta_timestamps(cfg)
image_transforms = make_image_transforms(cfg.image_transforms) if cfg.image_transforms.enable else None
image_transforms = None
if cfg.image_transforms.enable:
image_transforms = get_image_transforms(
brightness_weight=cfg.brightness.weight,
brightness_min_max=cfg.brightness.min_max,
contrast_weight=cfg.contrast.weight,
contrast_min_max=cfg.contrast.min_max,
saturation_weight=cfg.saturation.weight,
saturation_min_max=cfg.saturation.min_max,
hue_weight=cfg.hue.weight,
hue_min_max=cfg.hue.min_max,
sharpness_weight=cfg.sharpness.weight,
sharpness_min_max=cfg.sharpness.min_max,
max_num_transforms=cfg.max_num_transforms,
random_order=cfg.random_order,
)
if isinstance(cfg.dataset_repo_id, str):
dataset = LeRobotDataset(

View File

@@ -98,26 +98,60 @@ class RangeRandomSharpness(Transform):
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
def make_image_transforms(cfg, to_dtype: torch.dtype = torch.float32):
transforms_list = [
v2.ColorJitter(brightness=(cfg.brightness.min, cfg.brightness.max)),
v2.ColorJitter(contrast=(cfg.contrast.min, cfg.contrast.max)),
v2.ColorJitter(saturation=(cfg.saturation.min, cfg.saturation.max)),
v2.ColorJitter(hue=(cfg.hue.min, cfg.hue.max)),
RangeRandomSharpness(cfg.sharpness.min, cfg.sharpness.max),
]
transforms_weights = [
cfg.brightness.weight,
cfg.contrast.weight,
cfg.saturation.weight,
cfg.hue.weight,
cfg.sharpness.weight,
]
def get_image_transforms(
brightness_weight: float = 1.0,
brightness_min_max: tuple[float, float] | None = None,
contrast_weight: float = 1.0,
contrast_min_max: tuple[float, float] | None = None,
saturation_weight: float = 1.0,
saturation_min_max: tuple[float, float] | None = None,
hue_weight: float = 1.0,
hue_min_max: tuple[float, float] | None = None,
sharpness_weight: float = 1.0,
sharpness_min_max: tuple[float, float] | None = None,
max_num_transforms: int | None = None,
random_order: bool = False,
):
def check_value_error(name, weight, min_max):
if min_max is not None:
if len(min_max) != 2:
raise ValueError(f"`{name}_min_max` is expected to be a tuple of 2 dimensions, but {min_max} provided.")
if weight < 0.:
raise ValueError(f"`{name}_weight` is expected to be 0 or positive, but is negative ({weight}).")
transforms = RandomSubsetApply(
transforms_list, p=transforms_weights, n_subset=cfg.max_num_transforms, random_order=cfg.random_order
check_value_error("brightness", brightness_weight, brightness_min_max)
check_value_error("contrast", contrast_weight, contrast_min_max)
check_value_error("saturation", saturation_weight, saturation_min_max)
check_value_error("hue", hue_weight, hue_min_max)
check_value_error("sharpness", sharpness_weight, sharpness_min_max)
weights = []
transforms = []
if brightness_min_max is not None:
weights.append(brightness_weight)
transforms.append(v2.ColorJitter(brightness=brightness_min_max))
if contrast_min_max is not None:
weights.append(contrast_weight)
transforms.append(v2.ColorJitter(contrast=contrast_min_max))
if saturation_min_max is not None:
weights.append(saturation_weight)
transforms.append(v2.ColorJitter(saturation=saturation_min_max))
if hue_min_max is not None:
weights.append(hue_weight)
transforms.append(v2.ColorJitter(hue=hue_min_max))
if sharpness_min_max is not None:
weights.append(sharpness_weight)
transforms.append(RangeRandomSharpness(**sharpness_min_max))
if max_num_transforms is None:
n_subset = len(transforms)
else:
n_subset = min(len(transforms), max_num_transforms)
final_transforms = RandomSubsetApply(
transforms, p=weights, n_subset=n_subset, random_order=random_order
)
# return transforms
# return v2.Compose([transforms, v2.ToDtype(to_dtype, scale=True)])
return v2.Compose([transforms, v2.ToDtype(to_dtype, scale=False)])
# TODO(rcadene, aliberts): add v2.ToDtype float16?
return final_transforms