Add data augmentation in LeRobotDataset (#234)

Co-authored-by: Simon Alibert <alibert.sim@gmail.com>
Co-authored-by: Remi Cadene <re.cadene@gmail.com>
This commit is contained in:
Marina Barannikov
2024-06-11 19:20:55 +02:00
committed by GitHub
parent 1cf050d412
commit ff8f6aa6cd
10 changed files with 811 additions and 12 deletions

View File

@@ -19,6 +19,7 @@ import torch
from omegaconf import ListConfig, OmegaConf
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
from lerobot.common.datasets.transforms import get_image_transforms
def resolve_delta_timestamps(cfg):
@@ -71,17 +72,36 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
resolve_delta_timestamps(cfg)
# TODO(rcadene): add data augmentations
image_transforms = None
if cfg.training.image_transforms.enable:
image_transforms = get_image_transforms(
brightness_weight=cfg.brightness.weight,
brightness_min_max=cfg.brightness.min_max,
contrast_weight=cfg.contrast.weight,
contrast_min_max=cfg.contrast.min_max,
saturation_weight=cfg.saturation.weight,
saturation_min_max=cfg.saturation.min_max,
hue_weight=cfg.hue.weight,
hue_min_max=cfg.hue.min_max,
sharpness_weight=cfg.sharpness.weight,
sharpness_min_max=cfg.sharpness.min_max,
max_num_transforms=cfg.max_num_transforms,
random_order=cfg.random_order,
)
if isinstance(cfg.dataset_repo_id, str):
dataset = LeRobotDataset(
cfg.dataset_repo_id,
split=split,
delta_timestamps=cfg.training.get("delta_timestamps"),
image_transforms=image_transforms,
)
else:
dataset = MultiLeRobotDataset(
cfg.dataset_repo_id, split=split, delta_timestamps=cfg.training.get("delta_timestamps")
cfg.dataset_repo_id,
split=split,
delta_timestamps=cfg.training.get("delta_timestamps"),
image_transforms=image_transforms,
)
if cfg.get("override_dataset_stats"):