Improve video benchmark (#282)

Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
Co-authored-by: Remi <re.cadene@gmail.com>
This commit is contained in:
Simon Alibert
2024-07-09 20:20:25 +02:00
committed by GitHub
parent cc2f6e7404
commit e410e5d711
11 changed files with 985 additions and 772 deletions

View File

@@ -211,7 +211,7 @@ def _mock_download_raw_dora(raw_dir, num_frames=6, num_episodes=3, fps=30):
fname = f"{cam_key}_episode_{ep_idx:06d}.mp4"
video_path = raw_dir / "videos" / fname
encode_video_frames(tmp_imgs_dir, video_path, fps)
encode_video_frames(tmp_imgs_dir, video_path, fps, video_codec="libx264")
def _mock_download_raw(raw_dir, repo_id):
@@ -229,6 +229,23 @@ def _mock_download_raw(raw_dir, repo_id):
raise ValueError(repo_id)
def _mock_encode_video_frames(*args, **kwargs):
kwargs["video_codec"] = "libx264"
return encode_video_frames(*args, **kwargs)
def patch_encoder(raw_format, mocker):
format_module_map = {
"aloha_hdf5": "lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format.encode_video_frames",
"pusht_zarr": "lerobot.common.datasets.push_dataset_to_hub.pusht_zarr_format.encode_video_frames",
"xarm_pkl": "lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format.encode_video_frames",
"umi_zarr": "lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format.encode_video_frames",
}
if raw_format in format_module_map:
mocker.patch(format_module_map[raw_format], side_effect=_mock_encode_video_frames)
def test_push_dataset_to_hub_invalid_repo_id(tmpdir):
with pytest.raises(ValueError):
push_dataset_to_hub(Path(tmpdir), "raw_format", "invalid_repo_id")
@@ -262,7 +279,10 @@ def test_push_dataset_to_hub_out_dir_force_override_false(tmpdir):
],
)
@require_package_arg
def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_id, make_test_data):
def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_id, make_test_data, mocker):
# Patch `encode_video_frames` so that it uses 'libx264' instead of 'libsvtav1' for testing
patch_encoder(raw_format, mocker)
num_episodes = 3
tmpdir = Path(tmpdir)