In tests: default use_videos to False, create an mp4 file when it is True, and fix test_datasets and test_aggregate (all passing)

This commit is contained in:
Remi Cadene
2025-05-12 15:37:02 +02:00
parent e88af0e588
commit e07cb52baa
7 changed files with 81 additions and 29 deletions

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import shutil
from functools import partial
from pathlib import Path
from typing import Protocol
@@ -37,6 +38,7 @@ from lerobot.common.datasets.utils import (
get_hf_features_from_features,
hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import encode_video_frames
from tests.fixtures.constants import (
DEFAULT_FPS,
DUMMY_CAMERA_FEATURES,
@@ -95,7 +97,7 @@ def features_factory():
def _create_features(
motor_features: dict = DUMMY_MOTOR_FEATURES,
camera_features: dict = DUMMY_CAMERA_FEATURES,
use_videos: bool = True,
use_videos: bool = False,
) -> dict:
if use_videos:
camera_ft = {
@@ -129,7 +131,7 @@ def info_factory(features_factory):
video_path: str = DEFAULT_VIDEO_PATH,
motor_features: dict = DUMMY_MOTOR_FEATURES,
camera_features: dict = DUMMY_CAMERA_FEATURES,
use_videos: bool = True,
use_videos: bool = False,
) -> dict:
features = features_factory(motor_features, camera_features, use_videos)
return {
@@ -302,6 +304,38 @@ def episodes_factory(tasks_factory, stats_factory):
return _create_episodes
@pytest.fixture(scope="session")
def create_videos(info_factory, img_array_factory):
    """Provide a helper that encodes one video file per video feature under a root directory."""

    def _create_video_directory(
        root: Path,
        info: dict | None = None,
        total_episodes: int = 3,
        total_frames: int = 150,
        total_tasks: int = 1,
    ):
        """Generate frames for every video feature in ``info`` and encode them to video files.

        Args:
            root: Directory under which the temporary frames and the encoded videos are written.
            info: Dataset info dict. When None, it is built by ``info_factory`` from the totals below.
            total_episodes: Forwarded to ``info_factory``; only used when ``info`` is None.
            total_frames: Forwarded to ``info_factory``; only used when ``info`` is None.
            total_tasks: Forwarded to ``info_factory``; only used when ``info`` is None.
        """
        if info is None:
            # NOTE(review): which features come out as "video" here depends on
            # info_factory's own defaults -- confirm against that fixture.
            info = info_factory(
                total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
            )

        video_feats = {key: feats for key, feats in info["features"].items() if feats["dtype"] == "video"}

        # Scratch directory for the PNG frames; re-created and removed for each feature.
        tmp_dir = root / "tmp_images"
        for key, ft in video_feats.items():
            tmp_dir.mkdir(parents=True, exist_ok=True)
            try:
                # Create and save one image per frame of the dataset.
                for frame_index in range(info["total_frames"]):
                    # NOTE(review): height is read from shape[1] and width from shape[0];
                    # confirm this matches the axis order of the feature's "shape".
                    img = img_array_factory(height=ft["shape"][1], width=ft["shape"][0])
                    pil_img = PIL.Image.fromarray(img)
                    pil_img.save(tmp_dir / f"frame-{frame_index:06d}.png")

                # All frames of a feature go into a single file: chunk 0, file 0.
                video_path = root / DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=0, file_index=0)
                encode_video_frames(tmp_dir, video_path, fps=ft["video.fps"])
            finally:
                # Always drop the scratch frames, even if saving or encoding fails, so stale
                # PNGs cannot leak into a later encode that reuses the same directory.
                shutil.rmtree(tmp_dir)

    return _create_video_directory
@pytest.fixture(scope="session")
def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
def _create_hf_dataset(
@@ -338,7 +372,7 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
for key, ft in features.items():
if ft["dtype"] == "image":
robot_cols[key] = [
img_array_factory(height=ft["shapes"][1], width=ft["shapes"][0])
img_array_factory(height=ft["shape"][1], width=ft["shape"][0])
for _ in range(len(index_col))
]
elif ft["shape"][0] > 1 and ft["dtype"] != "video":
@@ -439,6 +473,7 @@ def lerobot_dataset_factory(
hf_dataset: datasets.Dataset | None = None,
**kwargs,
) -> LeRobotDataset:
# Instantiate objects
if info is None:
info = info_factory(
total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
@@ -448,19 +483,18 @@ def lerobot_dataset_factory(
if tasks is None:
tasks = tasks_factory(total_tasks=info["total_tasks"])
if episodes_metadata is None:
video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
episodes_metadata = episodes_factory(
features=info["features"],
fps=info["fps"],
total_episodes=info["total_episodes"],
total_frames=info["total_frames"],
video_keys=video_keys,
tasks=tasks,
multi_task=multi_task,
)
if not hf_dataset:
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes_metadata, fps=info["fps"])
# Write data on disk
mock_snapshot_download = mock_snapshot_download_factory(
info=info,
stats=stats,