forked from tangger/lerobot
WIP
This commit is contained in:
@@ -19,7 +19,7 @@ import torch
|
||||
from omegaconf import ListConfig, OmegaConf
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
|
||||
from lerobot.common.datasets.transforms import make_image_transforms
|
||||
from lerobot.common.datasets.transforms import get_image_transforms
|
||||
|
||||
|
||||
def resolve_delta_timestamps(cfg):
|
||||
@@ -72,7 +72,22 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
|
||||
|
||||
resolve_delta_timestamps(cfg)
|
||||
|
||||
image_transforms = make_image_transforms(cfg.image_transforms) if cfg.image_transforms.enable else None
|
||||
image_transforms = None
|
||||
if cfg.image_transforms.enable:
|
||||
image_transforms = get_image_transforms(
|
||||
brightness_weight=cfg.brightness.weight,
|
||||
brightness_min_max=cfg.brightness.min_max,
|
||||
contrast_weight=cfg.contrast.weight,
|
||||
contrast_min_max=cfg.contrast.min_max,
|
||||
saturation_weight=cfg.saturation.weight,
|
||||
saturation_min_max=cfg.saturation.min_max,
|
||||
hue_weight=cfg.hue.weight,
|
||||
hue_min_max=cfg.hue.min_max,
|
||||
sharpness_weight=cfg.sharpness.weight,
|
||||
sharpness_min_max=cfg.sharpness.min_max,
|
||||
max_num_transforms=cfg.max_num_transforms,
|
||||
random_order=cfg.random_order,
|
||||
)
|
||||
|
||||
if isinstance(cfg.dataset_repo_id, str):
|
||||
dataset = LeRobotDataset(
|
||||
|
||||
@@ -98,26 +98,60 @@ class RangeRandomSharpness(Transform):
|
||||
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
|
||||
|
||||
|
||||
def make_image_transforms(cfg, to_dtype: torch.dtype = torch.float32):
|
||||
transforms_list = [
|
||||
v2.ColorJitter(brightness=(cfg.brightness.min, cfg.brightness.max)),
|
||||
v2.ColorJitter(contrast=(cfg.contrast.min, cfg.contrast.max)),
|
||||
v2.ColorJitter(saturation=(cfg.saturation.min, cfg.saturation.max)),
|
||||
v2.ColorJitter(hue=(cfg.hue.min, cfg.hue.max)),
|
||||
RangeRandomSharpness(cfg.sharpness.min, cfg.sharpness.max),
|
||||
]
|
||||
transforms_weights = [
|
||||
cfg.brightness.weight,
|
||||
cfg.contrast.weight,
|
||||
cfg.saturation.weight,
|
||||
cfg.hue.weight,
|
||||
cfg.sharpness.weight,
|
||||
]
|
||||
def get_image_transforms(
|
||||
brightness_weight: float = 1.0,
|
||||
brightness_min_max: tuple[float, float] | None = None,
|
||||
contrast_weight: float = 1.0,
|
||||
contrast_min_max: tuple[float, float] | None = None,
|
||||
saturation_weight: float = 1.0,
|
||||
saturation_min_max: tuple[float, float] | None = None,
|
||||
hue_weight: float = 1.0,
|
||||
hue_min_max: tuple[float, float] | None = None,
|
||||
sharpness_weight: float = 1.0,
|
||||
sharpness_min_max: tuple[float, float] | None = None,
|
||||
max_num_transforms: int | None = None,
|
||||
random_order: bool = False,
|
||||
):
|
||||
|
||||
def check_value_error(name, weight, min_max):
|
||||
if min_max is not None:
|
||||
if len(min_max) != 2:
|
||||
raise ValueError(f"`{name}_min_max` is expected to be a tuple of 2 dimensions, but {min_max} provided.")
|
||||
if weight < 0.:
|
||||
raise ValueError(f"`{name}_weight` is expected to be 0 or positive, but is negative ({weight}).")
|
||||
|
||||
transforms = RandomSubsetApply(
|
||||
transforms_list, p=transforms_weights, n_subset=cfg.max_num_transforms, random_order=cfg.random_order
|
||||
check_value_error("brightness", brightness_weight, brightness_min_max)
|
||||
check_value_error("contrast", contrast_weight, contrast_min_max)
|
||||
check_value_error("saturation", saturation_weight, saturation_min_max)
|
||||
check_value_error("hue", hue_weight, hue_min_max)
|
||||
check_value_error("sharpness", sharpness_weight, sharpness_min_max)
|
||||
|
||||
weights = []
|
||||
transforms = []
|
||||
if brightness_min_max is not None:
|
||||
weights.append(brightness_weight)
|
||||
transforms.append(v2.ColorJitter(brightness=brightness_min_max))
|
||||
if contrast_min_max is not None:
|
||||
weights.append(contrast_weight)
|
||||
transforms.append(v2.ColorJitter(contrast=contrast_min_max))
|
||||
if saturation_min_max is not None:
|
||||
weights.append(saturation_weight)
|
||||
transforms.append(v2.ColorJitter(saturation=saturation_min_max))
|
||||
if hue_min_max is not None:
|
||||
weights.append(hue_weight)
|
||||
transforms.append(v2.ColorJitter(hue=hue_min_max))
|
||||
if sharpness_min_max is not None:
|
||||
weights.append(sharpness_weight)
|
||||
transforms.append(RangeRandomSharpness(**sharpness_min_max))
|
||||
|
||||
if max_num_transforms is None:
|
||||
n_subset = len(transforms)
|
||||
else:
|
||||
n_subset = min(len(transforms), max_num_transforms)
|
||||
|
||||
final_transforms = RandomSubsetApply(
|
||||
transforms, p=weights, n_subset=n_subset, random_order=random_order
|
||||
)
|
||||
|
||||
# return transforms
|
||||
# return v2.Compose([transforms, v2.ToDtype(to_dtype, scale=True)])
|
||||
return v2.Compose([transforms, v2.ToDtype(to_dtype, scale=False)])
|
||||
# TODO(rcadene, aliberts): add v2.ToDtype float16?
|
||||
return final_transforms
|
||||
|
||||
Reference in New Issue
Block a user