In tests: add `use_videos=False` by default, create an mp4 file when it is True, then fix test_datasets and test_aggregate (all passing)

This commit is contained in:
Remi Cadene
2025-05-12 15:37:02 +02:00
parent e88af0e588
commit e07cb52baa
7 changed files with 81 additions and 29 deletions

View File

@@ -17,6 +17,7 @@ from lerobot.common.datasets.utils import (
concat_video_files, concat_video_files,
get_parquet_file_size_in_mb, get_parquet_file_size_in_mb,
get_video_size_in_mb, get_video_size_in_mb,
safe_write_dataframe_to_parquet,
update_chunk_file_indices, update_chunk_file_indices,
write_info, write_info,
write_stats, write_stats,
@@ -97,6 +98,7 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
fps, robot_type, features = validate_all_metadata(all_metadata) fps, robot_type, features = validate_all_metadata(all_metadata)
video_keys = [key for key in features if features[key]["dtype"] == "video"] video_keys = [key for key in features if features[key]["dtype"] == "video"]
image_keys = [key for key in features if features[key]["dtype"] == "image"]
# Create resulting dataset folder # Create resulting dataset folder
aggr_meta = LeRobotDatasetMetadata.create( aggr_meta = LeRobotDatasetMetadata.create(
@@ -259,7 +261,7 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
# Update the existing parquet file with new rows # Update the existing parquet file with new rows
aggr_df = pd.read_parquet(aggr_path) aggr_df = pd.read_parquet(aggr_path)
df = pd.concat([aggr_df, df], ignore_index=True) df = pd.concat([aggr_df, df], ignore_index=True)
df.to_parquet(aggr_path) safe_write_dataframe_to_parquet(df, aggr_path, image_keys)
num_episodes += meta.total_episodes num_episodes += meta.total_episodes
num_frames += meta.total_frames num_frames += meta.total_frames

View File

@@ -63,6 +63,7 @@ from lerobot.common.datasets.utils import (
load_nested_dataset, load_nested_dataset,
load_stats, load_stats,
load_tasks, load_tasks,
safe_write_dataframe_to_parquet,
update_chunk_file_indices, update_chunk_file_indices,
validate_episode_buffer, validate_episode_buffer,
validate_frame, validate_frame,
@@ -1008,10 +1009,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Write the resulting dataframe from RAM to disk # Write the resulting dataframe from RAM to disk
path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx) path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
path.parent.mkdir(parents=True, exist_ok=True) path.parent.mkdir(parents=True, exist_ok=True)
if len(self.meta.image_keys) > 0: safe_write_dataframe_to_parquet(df, path, self.meta.image_keys)
datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
else:
df.to_parquet(path)
# Update the Hugging Face dataset by reloading it. # Update the Hugging Face dataset by reloading it.
# This process should be fast because only the latest Parquet file has been modified. # This process should be fast because only the latest Parquet file has been modified.

View File

@@ -890,3 +890,11 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
f"In episode_buffer not in features: {buffer_keys - set(features)}" f"In episode_buffer not in features: {buffer_keys - set(features)}"
f"In features not in episode_buffer: {set(features) - buffer_keys}" f"In features not in episode_buffer: {set(features) - buffer_keys}"
) )
def safe_write_dataframe_to_parquet(df: pandas.DataFrame, path: Path, image_keys: list[str]):
if len(image_keys) > 0:
# TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
else:
df.to_parquet(path)

View File

@@ -71,7 +71,7 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create) dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create)
root_init = tmp_path / "init" root_init = tmp_path / "init"
dataset_init = lerobot_dataset_factory(root=root_init) dataset_init = lerobot_dataset_factory(root=root_init, total_episodes=1, total_frames=1)
init_attr = set(vars(dataset_init).keys()) init_attr = set(vars(dataset_init).keys())
create_attr = set(vars(dataset_create).keys()) create_attr = set(vars(dataset_create).keys())

View File

@@ -29,8 +29,8 @@ DUMMY_MOTOR_FEATURES = {
}, },
} }
DUMMY_CAMERA_FEATURES = { DUMMY_CAMERA_FEATURES = {
"laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None}, "laptop": {"shape": (64, 96, 3), "names": ["height", "width", "channels"], "info": None},
"phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None}, "phone": {"shape": (64, 96, 3), "names": ["height", "width", "channels"], "info": None},
} }
DEFAULT_FPS = 30 DEFAULT_FPS = 30
DUMMY_VIDEO_INFO = { DUMMY_VIDEO_INFO = {

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import random import random
import shutil
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
from typing import Protocol from typing import Protocol
@@ -37,6 +38,7 @@ from lerobot.common.datasets.utils import (
get_hf_features_from_features, get_hf_features_from_features,
hf_transform_to_torch, hf_transform_to_torch,
) )
from lerobot.common.datasets.video_utils import encode_video_frames
from tests.fixtures.constants import ( from tests.fixtures.constants import (
DEFAULT_FPS, DEFAULT_FPS,
DUMMY_CAMERA_FEATURES, DUMMY_CAMERA_FEATURES,
@@ -95,7 +97,7 @@ def features_factory():
def _create_features( def _create_features(
motor_features: dict = DUMMY_MOTOR_FEATURES, motor_features: dict = DUMMY_MOTOR_FEATURES,
camera_features: dict = DUMMY_CAMERA_FEATURES, camera_features: dict = DUMMY_CAMERA_FEATURES,
use_videos: bool = True, use_videos: bool = False,
) -> dict: ) -> dict:
if use_videos: if use_videos:
camera_ft = { camera_ft = {
@@ -129,7 +131,7 @@ def info_factory(features_factory):
video_path: str = DEFAULT_VIDEO_PATH, video_path: str = DEFAULT_VIDEO_PATH,
motor_features: dict = DUMMY_MOTOR_FEATURES, motor_features: dict = DUMMY_MOTOR_FEATURES,
camera_features: dict = DUMMY_CAMERA_FEATURES, camera_features: dict = DUMMY_CAMERA_FEATURES,
use_videos: bool = True, use_videos: bool = False,
) -> dict: ) -> dict:
features = features_factory(motor_features, camera_features, use_videos) features = features_factory(motor_features, camera_features, use_videos)
return { return {
@@ -302,6 +304,38 @@ def episodes_factory(tasks_factory, stats_factory):
return _create_episodes return _create_episodes
@pytest.fixture(scope="session")
def create_videos(info_factory, img_array_factory):
    # Session-scoped factory fixture: for every "video" feature in `info`, encodes a
    # dummy mp4 under `root` from PNG frames generated by `img_array_factory`.
    def _create_video_directory(
        root: Path,
        info: dict | None = None,
        total_episodes: int = 3,
        total_frames: int = 150,
        total_tasks: int = 1,
    ):
        # Build a default info dict when the caller does not supply one.
        if info is None:
            info = info_factory(
                total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
            )
        # Only features declared with dtype "video" get an encoded mp4.
        video_feats = {key: feats for key, feats in info["features"].items() if feats["dtype"] == "video"}
        for key, ft in video_feats.items():
            # create and save images
            # The tmp_images dir is recreated for each video key and deleted after encoding.
            tmp_dir = root / "tmp_images"
            tmp_dir.mkdir(parents=True, exist_ok=True)
            for frame_index in range(info["total_frames"]):
                # NOTE(review): feature shape names are ["height", "width", "channels"], so
                # shape[0] should be height and shape[1] width — here height=shape[1] and
                # width=shape[0] look swapped; confirm against img_array_factory's convention.
                img = img_array_factory(height=ft["shape"][1], width=ft["shape"][0])
                pil_img = PIL.Image.fromarray(img)
                # Zero-padded frame names keep the files in encoding order.
                path = tmp_dir / f"frame-{frame_index:06d}.png"
                pil_img.save(path)
            # Single chunk/file (index 0/0); matches the layout mocked in hub fixtures.
            video_path = root / DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=0, file_index=0)
            encode_video_frames(tmp_dir, video_path, fps=ft["video.fps"])
            shutil.rmtree(tmp_dir)

    return _create_video_directory
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory): def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
def _create_hf_dataset( def _create_hf_dataset(
@@ -338,7 +372,7 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
for key, ft in features.items(): for key, ft in features.items():
if ft["dtype"] == "image": if ft["dtype"] == "image":
robot_cols[key] = [ robot_cols[key] = [
img_array_factory(height=ft["shapes"][1], width=ft["shapes"][0]) img_array_factory(height=ft["shape"][1], width=ft["shape"][0])
for _ in range(len(index_col)) for _ in range(len(index_col))
] ]
elif ft["shape"][0] > 1 and ft["dtype"] != "video": elif ft["shape"][0] > 1 and ft["dtype"] != "video":
@@ -439,6 +473,7 @@ def lerobot_dataset_factory(
hf_dataset: datasets.Dataset | None = None, hf_dataset: datasets.Dataset | None = None,
**kwargs, **kwargs,
) -> LeRobotDataset: ) -> LeRobotDataset:
# Instantiate objects
if info is None: if info is None:
info = info_factory( info = info_factory(
total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
@@ -448,19 +483,18 @@ def lerobot_dataset_factory(
if tasks is None: if tasks is None:
tasks = tasks_factory(total_tasks=info["total_tasks"]) tasks = tasks_factory(total_tasks=info["total_tasks"])
if episodes_metadata is None: if episodes_metadata is None:
video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
episodes_metadata = episodes_factory( episodes_metadata = episodes_factory(
features=info["features"], features=info["features"],
fps=info["fps"], fps=info["fps"],
total_episodes=info["total_episodes"], total_episodes=info["total_episodes"],
total_frames=info["total_frames"], total_frames=info["total_frames"],
video_keys=video_keys,
tasks=tasks, tasks=tasks,
multi_task=multi_task, multi_task=multi_task,
) )
if not hf_dataset: if not hf_dataset:
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes_metadata, fps=info["fps"]) hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes_metadata, fps=info["fps"])
# Write data on disk
mock_snapshot_download = mock_snapshot_download_factory( mock_snapshot_download = mock_snapshot_download_factory(
info=info, info=info,
stats=stats, stats=stats,

42
tests/fixtures/hub.py vendored
View File

@@ -22,6 +22,7 @@ from lerobot.common.datasets.utils import (
DEFAULT_DATA_PATH, DEFAULT_DATA_PATH,
DEFAULT_EPISODES_PATH, DEFAULT_EPISODES_PATH,
DEFAULT_TASKS_PATH, DEFAULT_TASKS_PATH,
DEFAULT_VIDEO_PATH,
INFO_PATH, INFO_PATH,
STATS_PATH, STATS_PATH,
) )
@@ -40,6 +41,7 @@ def mock_snapshot_download_factory(
create_episodes, create_episodes,
hf_dataset_factory, hf_dataset_factory,
create_hf_dataset, create_hf_dataset,
create_videos,
): ):
""" """
This factory allows to patch snapshot_download such that when called, it will create expected files rather This factory allows to patch snapshot_download such that when called, it will create expected files rather
@@ -91,40 +93,48 @@ def mock_snapshot_download_factory(
DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0), DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0),
] ]
video_keys = [key for key, feats in info["features"].items() if feats["dtype"] == "video"]
for key in video_keys:
all_files.append(DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=0, file_index=0))
allowed_files = filter_repo_objects( allowed_files = filter_repo_objects(
all_files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns all_files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
) )
has_info = False request_info = False
has_tasks = False request_tasks = False
has_episodes = False request_episodes = False
has_stats = False request_stats = False
has_data = False request_data = False
request_videos = False
for rel_path in allowed_files: for rel_path in allowed_files:
if rel_path.startswith("meta/info.json"): if rel_path.startswith("meta/info.json"):
has_info = True request_info = True
elif rel_path.startswith("meta/stats"): elif rel_path.startswith("meta/stats"):
has_stats = True request_stats = True
elif rel_path.startswith("meta/tasks"): elif rel_path.startswith("meta/tasks"):
has_tasks = True request_tasks = True
elif rel_path.startswith("meta/episodes"): elif rel_path.startswith("meta/episodes"):
has_episodes = True request_episodes = True
elif rel_path.startswith("data/"): elif rel_path.startswith("data/"):
has_data = True request_data = True
elif rel_path.startswith("videos/"):
request_videos = True
else: else:
raise ValueError(f"{rel_path} not supported.") raise ValueError(f"{rel_path} not supported.")
if has_info: if request_info:
create_info(local_dir, info) create_info(local_dir, info)
if has_stats: if request_stats:
create_stats(local_dir, stats) create_stats(local_dir, stats)
if has_tasks: if request_tasks:
create_tasks(local_dir, tasks) create_tasks(local_dir, tasks)
if has_episodes: if request_episodes:
create_episodes(local_dir, episodes) create_episodes(local_dir, episodes)
# TODO(rcadene): create_videos? if request_data:
if has_data:
create_hf_dataset(local_dir, hf_dataset) create_hf_dataset(local_dir, hf_dataset)
if request_videos:
create_videos(root=local_dir, info=info)
return str(local_dir) return str(local_dir)