In tests: Add use_videos=False by default, Create mp4 file if True, then fix test_datasets and test_aggregate (all passing)

This commit is contained in:
Remi Cadene
2025-05-12 15:37:02 +02:00
parent e88af0e588
commit e07cb52baa
7 changed files with 81 additions and 29 deletions

View File

@@ -17,6 +17,7 @@ from lerobot.common.datasets.utils import (
concat_video_files,
get_parquet_file_size_in_mb,
get_video_size_in_mb,
safe_write_dataframe_to_parquet,
update_chunk_file_indices,
write_info,
write_stats,
@@ -97,6 +98,7 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
fps, robot_type, features = validate_all_metadata(all_metadata)
video_keys = [key for key in features if features[key]["dtype"] == "video"]
image_keys = [key for key in features if features[key]["dtype"] == "image"]
# Create resulting dataset folder
aggr_meta = LeRobotDatasetMetadata.create(
@@ -259,7 +261,7 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
# Update the existing parquet file with new rows
aggr_df = pd.read_parquet(aggr_path)
df = pd.concat([aggr_df, df], ignore_index=True)
df.to_parquet(aggr_path)
safe_write_dataframe_to_parquet(df, aggr_path, image_keys)
num_episodes += meta.total_episodes
num_frames += meta.total_frames

View File

@@ -63,6 +63,7 @@ from lerobot.common.datasets.utils import (
load_nested_dataset,
load_stats,
load_tasks,
safe_write_dataframe_to_parquet,
update_chunk_file_indices,
validate_episode_buffer,
validate_frame,
@@ -1008,10 +1009,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Write the resulting dataframe from RAM to disk
path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
path.parent.mkdir(parents=True, exist_ok=True)
if len(self.meta.image_keys) > 0:
datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
else:
df.to_parquet(path)
safe_write_dataframe_to_parquet(df, path, self.meta.image_keys)
# Update the Hugging Face dataset by reloading it.
# This process should be fast because only the latest Parquet file has been modified.

View File

@@ -890,3 +890,11 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
f"In episode_buffer not in features: {buffer_keys - set(features)}"
f"In features not in episode_buffer: {set(features) - buffer_keys}"
)
def safe_write_dataframe_to_parquet(df: pandas.DataFrame, path: Path, image_keys: list[str]):
if len(image_keys) > 0:
# TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
else:
df.to_parquet(path)