Enable video_reader backend (#220)

Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
Simon Alibert
2024-06-19 17:15:25 +02:00
committed by GitHub
parent 48951662f2
commit 2abef3bef9
11 changed files with 464 additions and 220 deletions

View File

@@ -48,6 +48,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
split: str = "train",
image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None,
video_backend: str | None = None,
):
super().__init__()
self.repo_id = repo_id
@@ -69,6 +70,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.info = load_info(repo_id, version, root)
if self.video:
self.videos_dir = load_videos(repo_id, version, root)
self.video_backend = video_backend if video_backend is not None else "pyav"
@property
def fps(self) -> int:
@@ -149,6 +151,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.video_frame_keys,
self.videos_dir,
self.tolerance_s,
self.video_backend,
)
if self.image_transforms is not None:
@@ -188,6 +191,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
stats=None,
info=None,
videos_dir=None,
video_backend=None,
) -> "LeRobotDataset":
"""Create a LeRobot Dataset from existing data and attributes instead of loading from the filesystem.
@@ -210,6 +214,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.stats = stats
obj.info = info if info is not None else {}
obj.videos_dir = videos_dir
obj.video_backend = video_backend if video_backend is not None else "pyav"
return obj
@@ -228,6 +233,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
split: str = "train",
image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None,
video_backend: str | None = None,
):
super().__init__()
self.repo_ids = repo_ids
@@ -241,6 +247,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
split=split,
delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
video_backend=video_backend,
)
for repo_id in repo_ids
]