Add data augmentation in LeRobotDataset (#234)
Co-authored-by: Simon Alibert <alibert.sim@gmail.com> Co-authored-by: Remi Cadene <re.cadene@gmail.com>
This commit is contained in:
committed by
GitHub
parent
1cf050d412
commit
ff8f6aa6cd
@@ -46,7 +46,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
version: str | None = CODEBASE_VERSION,
|
||||
root: Path | None = DATA_DIR,
|
||||
split: str = "train",
|
||||
transform: Callable | None = None,
|
||||
image_transforms: Callable | None = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -54,7 +54,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.version = version
|
||||
self.root = root
|
||||
self.split = split
|
||||
self.transform = transform
|
||||
self.image_transforms = image_transforms
|
||||
self.delta_timestamps = delta_timestamps
|
||||
# load data from hub or locally when root is provided
|
||||
# TODO(rcadene, aliberts): implement faster transfer
|
||||
@@ -151,8 +151,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.tolerance_s,
|
||||
)
|
||||
|
||||
if self.transform is not None:
|
||||
item = self.transform(item)
|
||||
if self.image_transforms is not None:
|
||||
for cam in self.camera_keys:
|
||||
item[cam] = self.image_transforms(item[cam])
|
||||
|
||||
return item
|
||||
|
||||
@@ -168,7 +169,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
f" Recorded Frames per Second: {self.fps},\n"
|
||||
f" Camera Keys: {self.camera_keys},\n"
|
||||
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
|
||||
f" Transformations: {self.transform},\n"
|
||||
f" Transformations: {self.image_transforms},\n"
|
||||
f")"
|
||||
)
|
||||
|
||||
@@ -202,7 +203,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.version = version
|
||||
obj.root = root
|
||||
obj.split = split
|
||||
obj.transform = transform
|
||||
obj.image_transforms = transform
|
||||
obj.delta_timestamps = delta_timestamps
|
||||
obj.hf_dataset = hf_dataset
|
||||
obj.episode_data_index = episode_data_index
|
||||
@@ -225,7 +226,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
version: str | None = CODEBASE_VERSION,
|
||||
root: Path | None = DATA_DIR,
|
||||
split: str = "train",
|
||||
transform: Callable | None = None,
|
||||
image_transforms: Callable | None = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -239,7 +240,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
root=root,
|
||||
split=split,
|
||||
delta_timestamps=delta_timestamps,
|
||||
transform=transform,
|
||||
image_transforms=image_transforms,
|
||||
)
|
||||
for repo_id in repo_ids
|
||||
]
|
||||
@@ -274,7 +275,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
self.version = version
|
||||
self.root = root
|
||||
self.split = split
|
||||
self.transform = transform
|
||||
self.image_transforms = image_transforms
|
||||
self.delta_timestamps = delta_timestamps
|
||||
self.stats = aggregate_stats(self._datasets)
|
||||
|
||||
@@ -380,6 +381,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
for data_key in self.disabled_data_keys:
|
||||
if data_key in item:
|
||||
del item[data_key]
|
||||
|
||||
return item
|
||||
|
||||
def __repr__(self):
|
||||
@@ -394,6 +396,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
f" Recorded Frames per Second: {self.fps},\n"
|
||||
f" Camera Keys: {self.camera_keys},\n"
|
||||
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
|
||||
f" Transformations: {self.transform},\n"
|
||||
f" Transformations: {self.image_transforms},\n"
|
||||
f")"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user