WIP after Francesco discussion

This commit is contained in:
Remi Cadene
2025-05-28 17:29:41 +02:00
parent f07887e8d1
commit 8746276d41
3 changed files with 34 additions and 28 deletions

View File

@@ -18,7 +18,7 @@ from lerobot.common.datasets.utils import (
concat_video_files,
get_parquet_file_size_in_mb,
get_video_size_in_mb,
safe_write_dataframe_to_parquet,
to_parquet_with_hf_images,
update_chunk_file_indices,
write_info,
write_stats,
@@ -125,11 +125,14 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
key: {"chunk": 0, "file": 0, "latest_duration": 0, "episode_duration": 0} for key in video_keys
}
dst_meta.episodes = {}
# Process each dataset
for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
videos_idx = aggregate_videos(src_meta, dst_meta, videos_idx)
data_idx = aggregate_data(src_meta, dst_meta, data_idx)
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx, video_keys)
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx, video_keys, image_keys)
dst_meta.info["total_episodes"] += src_meta.total_episodes
dst_meta.info["total_frames"] += src_meta.total_frames
@@ -205,7 +208,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx):
file_idx,
)
if aggr_size_in_mb + size_in_mb >= DEFAULT_DATA_FILE_SIZE_IN_MB:
if src_size + dst_size >= DEFAULT_DATA_FILE_SIZE_IN_MB:
# Size limit is reached, prepare new parquet file
aggr_data_chunk_idx, aggr_data_file_idx = update_chunk_file_indices(
aggr_data_chunk_idx, aggr_data_file_idx, DEFAULT_CHUNK_SIZE
@@ -219,7 +222,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx):
# Update the existing parquet file with new rows
aggr_df = pd.read_parquet(aggr_path)
df = pd.concat([aggr_df, df], ignore_index=True)
safe_write_dataframe_to_parquet(df, aggr_path, image_keys)
to_parquet_with_hf_images(df, aggr_path, dst_meta.image_keys)
return videos_idx
@@ -238,17 +241,14 @@ def aggregate_data(src_meta, dst_meta, data_idx):
df = pd.read_parquet(src_path)
df = update_data_df(df, src_meta, dst_meta)
dst_path = aggr_root / DEFAULT_DATA_PATH.format(
chunk_index=data_idx["chunk"], file_index=data_idx["file"]
)
data_idx = write_parquet_safely(
data_idx = append_or_create_parquet_file(
df,
src_path,
dst_path,
data_idx,
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_PATH,
contains_images=len(dst_meta.image_keys) > 0
)
return data_idx
@@ -278,13 +278,9 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
# for k in video_keys:
# video_idx[k]["latest_duration"] += video_idx[k]["episode_duration"]
dst_path = dst_meta.root / DEFAULT_EPISODES_PATH.format(
chunk_index=meta_idx["chunk"], file_index=meta_idx["file"]
)
write_parquet_safely(
append_or_create_parquet_file(
df,
src_path,
dst_path,
meta_idx,
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_CHUNK_SIZE,
@@ -294,14 +290,14 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
return meta_idx
def write_parquet_safely(
def append_or_create_parquet_file(
df: pd.DataFrame,
src_path: Path,
dst_path: Path,
idx: dict[str, int],
max_mb: float,
chunk_size: int,
default_path: str,
contains_images: bool = False,
):
"""
Safely appends or creates a Parquet file at dst_path based on size constraints.
@@ -309,7 +305,6 @@ def write_parquet_safely(
Parameters:
df (pd.DataFrame): Data to write.
src_path (Path): Path to source file (used to get size).
dst_path (Path): Target path for writing.
idx (dict): Dictionary containing 'chunk' and 'file' indices.
max_mb (float): Maximum allowed file size in MB.
chunk_size (int): Maximum number of files per chunk.
@@ -318,6 +313,10 @@ def write_parquet_safely(
Returns:
dict: Updated index dictionary.
"""
# Initial destination path
dst_path = aggr_root / DEFAULT_DATA_PATH.format(
chunk_index=idx["chunk"], file_index=idx["file"]
)
# If destination file doesn't exist, just write the new one
if not dst_path.exists():
@@ -334,12 +333,16 @@ def write_parquet_safely(
idx["chunk"], idx["file"] = update_chunk_file_indices(idx["chunk"], idx["file"], chunk_size)
new_path = dst_path.parent / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"])
new_path.parent.mkdir(parents=True, exist_ok=True)
df.to_parquet(new_path)
final_df = df
else:
# Append to existing file
existing_df = pd.read_parquet(dst_path)
combined_df = pd.concat([existing_df, df], ignore_index=True)
combined_df.to_parquet(dst_path)
final_df = pd.concat([existing_df, df], ignore_index=True)
if contains_images:
to_parquet_with_hf_images(final_df, new_path)
else:
final_df.to_parquet(new_path)
return idx

View File

@@ -63,7 +63,7 @@ from lerobot.common.datasets.utils import (
load_nested_dataset,
load_stats,
load_tasks,
safe_write_dataframe_to_parquet,
to_parquet_with_hf_images,
update_chunk_file_indices,
validate_episode_buffer,
validate_frame,
@@ -1009,7 +1009,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Write the resulting dataframe from RAM to disk
path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
path.parent.mkdir(parents=True, exist_ok=True)
safe_write_dataframe_to_parquet(df, path, self.meta.image_keys)
if len(self.meta.image_keys) > 0:
to_parquet_with_hf_images(df, path)
else:
df.to_parquet(path)
# Update the Hugging Face dataset by reloading it.
# This process should be fast because only the latest Parquet file has been modified.

View File

@@ -892,9 +892,9 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
)
def safe_write_dataframe_to_parquet(df: pandas.DataFrame, path: Path, image_keys: list[str]):
if len(image_keys) > 0:
# TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
else:
df.to_parquet(path)
def to_parquet_with_hf_images(df: pandas.DataFrame, path: Path):
""" This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset.
This way, it can be loaded by HF dataset and correctly formated images are returned.
"""
# TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)