forked from tangger/lerobot
rename to image_transforms
This commit is contained in:
@@ -19,7 +19,7 @@ import torch
|
||||
from omegaconf import ListConfig, OmegaConf
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
|
||||
from lerobot.common.datasets.transforms import make_transforms
|
||||
from lerobot.common.datasets.transforms import make_image_transforms
|
||||
|
||||
|
||||
def resolve_delta_timestamps(cfg):
|
||||
@@ -72,21 +72,21 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
|
||||
|
||||
resolve_delta_timestamps(cfg)
|
||||
|
||||
transform = make_transforms(cfg.image_transform) if cfg.image_transform.enable else None
|
||||
image_transforms = make_image_transforms(cfg.image_transforms) if cfg.image_transforms.enable else None
|
||||
|
||||
if isinstance(cfg.dataset_repo_id, str):
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset_repo_id,
|
||||
split=split,
|
||||
delta_timestamps=cfg.training.get("delta_timestamps"),
|
||||
transform=transform,
|
||||
image_transforms=image_transforms,
|
||||
)
|
||||
else:
|
||||
dataset = MultiLeRobotDataset(
|
||||
cfg.dataset_repo_id,
|
||||
split=split,
|
||||
delta_timestamps=cfg.training.get("delta_timestamps"),
|
||||
transform=transform,
|
||||
image_transforms=image_transforms,
|
||||
)
|
||||
|
||||
if cfg.get("override_dataset_stats"):
|
||||
|
||||
@@ -46,7 +46,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
version: str | None = CODEBASE_VERSION,
|
||||
root: Path | None = DATA_DIR,
|
||||
split: str = "train",
|
||||
transform: Callable | None = None,
|
||||
image_transforms: Callable | None = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -54,7 +54,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.version = version
|
||||
self.root = root
|
||||
self.split = split
|
||||
self.transform = transform
|
||||
self.transform = image_transforms
|
||||
self.delta_timestamps = delta_timestamps
|
||||
# load data from hub or locally when root is provided
|
||||
# TODO(rcadene, aliberts): implement faster transfer
|
||||
@@ -226,7 +226,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
version: str | None = CODEBASE_VERSION,
|
||||
root: Path | None = DATA_DIR,
|
||||
split: str = "train",
|
||||
transform: Callable | None = None,
|
||||
image_transforms: Callable | None = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -240,7 +240,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
root=root,
|
||||
split=split,
|
||||
delta_timestamps=delta_timestamps,
|
||||
transform=transform,
|
||||
image_transforms=image_transforms,
|
||||
)
|
||||
for repo_id in repo_ids
|
||||
]
|
||||
@@ -275,7 +275,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
self.version = version
|
||||
self.root = root
|
||||
self.split = split
|
||||
self.transform = transform
|
||||
self.transform = image_transforms
|
||||
self.delta_timestamps = delta_timestamps
|
||||
self.stats = aggregate_stats(self._datasets)
|
||||
|
||||
|
||||
@@ -98,7 +98,7 @@ class RangeRandomSharpness(Transform):
|
||||
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
|
||||
|
||||
|
||||
def make_transforms(cfg, to_dtype: torch.dtype = torch.float32):
|
||||
def make_image_transforms(cfg, to_dtype: torch.dtype = torch.float32):
|
||||
transforms_list = [
|
||||
v2.ColorJitter(brightness=(cfg.brightness.min, cfg.brightness.max)),
|
||||
v2.ColorJitter(contrast=(cfg.contrast.min, cfg.contrast.max)),
|
||||
|
||||
Reference in New Issue
Block a user