rename to image_transforms

This commit is contained in:
Simon Alibert
2024-06-06 16:50:22 +00:00
parent a86f387554
commit c45dd8f848
8 changed files with 23 additions and 24 deletions

View File

@@ -19,7 +19,7 @@ import torch
from omegaconf import ListConfig, OmegaConf
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
from lerobot.common.datasets.transforms import make_transforms
from lerobot.common.datasets.transforms import make_image_transforms
def resolve_delta_timestamps(cfg):
@@ -72,21 +72,21 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
resolve_delta_timestamps(cfg)
transform = make_transforms(cfg.image_transform) if cfg.image_transform.enable else None
image_transforms = make_image_transforms(cfg.image_transforms) if cfg.image_transforms.enable else None
if isinstance(cfg.dataset_repo_id, str):
dataset = LeRobotDataset(
cfg.dataset_repo_id,
split=split,
delta_timestamps=cfg.training.get("delta_timestamps"),
transform=transform,
image_transforms=image_transforms,
)
else:
dataset = MultiLeRobotDataset(
cfg.dataset_repo_id,
split=split,
delta_timestamps=cfg.training.get("delta_timestamps"),
transform=transform,
image_transforms=image_transforms,
)
if cfg.get("override_dataset_stats"):

View File

@@ -46,7 +46,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
version: str | None = CODEBASE_VERSION,
root: Path | None = DATA_DIR,
split: str = "train",
transform: Callable | None = None,
image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None,
):
super().__init__()
@@ -54,7 +54,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.version = version
self.root = root
self.split = split
self.transform = transform
self.transform = image_transforms
self.delta_timestamps = delta_timestamps
# load data from hub or locally when root is provided
# TODO(rcadene, aliberts): implement faster transfer
@@ -226,7 +226,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
version: str | None = CODEBASE_VERSION,
root: Path | None = DATA_DIR,
split: str = "train",
transform: Callable | None = None,
image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None,
):
super().__init__()
@@ -240,7 +240,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
root=root,
split=split,
delta_timestamps=delta_timestamps,
transform=transform,
image_transforms=image_transforms,
)
for repo_id in repo_ids
]
@@ -275,7 +275,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
self.version = version
self.root = root
self.split = split
self.transform = transform
self.transform = image_transforms
self.delta_timestamps = delta_timestamps
self.stats = aggregate_stats(self._datasets)

View File

@@ -98,7 +98,7 @@ class RangeRandomSharpness(Transform):
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
def make_transforms(cfg, to_dtype: torch.dtype = torch.float32):
def make_image_transforms(cfg, to_dtype: torch.dtype = torch.float32):
transforms_list = [
v2.ColorJitter(brightness=(cfg.brightness.min, cfg.brightness.max)),
v2.ColorJitter(contrast=(cfg.contrast.min, cfg.contrast.max)),