WIP after Francesco discussion

Remi Cadene
2025-05-28 17:29:41 +02:00
parent f07887e8d1
commit 8746276d41
3 changed files with 34 additions and 28 deletions

View File

@@ -18,7 +18,7 @@ from lerobot.common.datasets.utils import (
     concat_video_files,
     get_parquet_file_size_in_mb,
     get_video_size_in_mb,
-    safe_write_dataframe_to_parquet,
+    to_parquet_with_hf_images,
     update_chunk_file_indices,
     write_info,
     write_stats,
@@ -125,11 +125,14 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
         key: {"chunk": 0, "file": 0, "latest_duration": 0, "episode_duration": 0} for key in video_keys
     }
 
+    dst_meta.episodes = {}
+
     # Process each dataset
     for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
         videos_idx = aggregate_videos(src_meta, dst_meta, videos_idx)
         data_idx = aggregate_data(src_meta, dst_meta, data_idx)
-        meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx, video_keys)
+        meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx, video_keys, image_keys)
 
         dst_meta.info["total_episodes"] += src_meta.total_episodes
         dst_meta.info["total_frames"] += src_meta.total_frames
@@ -205,7 +208,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx):
             file_idx,
         )
 
-        if aggr_size_in_mb + size_in_mb >= DEFAULT_DATA_FILE_SIZE_IN_MB:
+        if src_size + dst_size >= DEFAULT_DATA_FILE_SIZE_IN_MB:
             # Size limit is reached, prepare new parquet file
             aggr_data_chunk_idx, aggr_data_file_idx = update_chunk_file_indices(
                 aggr_data_chunk_idx, aggr_data_file_idx, DEFAULT_CHUNK_SIZE
@@ -219,7 +222,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx):
             # Update the existing parquet file with new rows
             aggr_df = pd.read_parquet(aggr_path)
             df = pd.concat([aggr_df, df], ignore_index=True)
 
-        safe_write_dataframe_to_parquet(df, aggr_path, image_keys)
+        to_parquet_with_hf_images(df, aggr_path, dst_meta.image_keys)
 
     return videos_idx
@@ -238,17 +241,14 @@ def aggregate_data(src_meta, dst_meta, data_idx):
         df = pd.read_parquet(src_path)
         df = update_data_df(df, src_meta, dst_meta)
 
-        dst_path = aggr_root / DEFAULT_DATA_PATH.format(
-            chunk_index=data_idx["chunk"], file_index=data_idx["file"]
-        )
-        data_idx = write_parquet_safely(
+        data_idx = append_or_create_parquet_file(
             df,
             src_path,
-            dst_path,
             data_idx,
             DEFAULT_DATA_FILE_SIZE_IN_MB,
             DEFAULT_CHUNK_SIZE,
             DEFAULT_DATA_PATH,
+            contains_images=len(dst_meta.image_keys) > 0
         )
 
     return data_idx
@@ -278,13 +278,9 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
     # for k in video_keys:
     #     video_idx[k]["latest_duration"] += video_idx[k]["episode_duration"]
 
-    dst_path = dst_meta.root / DEFAULT_EPISODES_PATH.format(
-        chunk_index=meta_idx["chunk"], file_index=meta_idx["file"]
-    )
-    write_parquet_safely(
+    append_or_create_parquet_file(
         df,
         src_path,
-        dst_path,
         meta_idx,
         DEFAULT_DATA_FILE_SIZE_IN_MB,
         DEFAULT_CHUNK_SIZE,
@@ -294,14 +290,14 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
     return meta_idx
 
 
-def write_parquet_safely(
+def append_or_create_parquet_file(
     df: pd.DataFrame,
     src_path: Path,
-    dst_path: Path,
     idx: dict[str, int],
     max_mb: float,
     chunk_size: int,
     default_path: str,
+    contains_images: bool = False,
 ):
     """
     Safely appends or creates a Parquet file at dst_path based on size constraints.
@@ -309,7 +305,6 @@ def write_parquet_safely(
     Parameters:
         df (pd.DataFrame): Data to write.
         src_path (Path): Path to source file (used to get size).
-        dst_path (Path): Target path for writing.
         idx (dict): Dictionary containing 'chunk' and 'file' indices.
         max_mb (float): Maximum allowed file size in MB.
         chunk_size (int): Maximum number of files per chunk.
@@ -318,6 +313,10 @@ def write_parquet_safely(
     Returns:
         dict: Updated index dictionary.
     """
+    # Initial destination path
+    dst_path = aggr_root / DEFAULT_DATA_PATH.format(
+        chunk_index=idx["chunk"], file_index=idx["file"]
+    )
 
     # If destination file doesn't exist, just write the new one
     if not dst_path.exists():
@@ -334,12 +333,16 @@ def write_parquet_safely(
         idx["chunk"], idx["file"] = update_chunk_file_indices(idx["chunk"], idx["file"], chunk_size)
         new_path = dst_path.parent / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"])
         new_path.parent.mkdir(parents=True, exist_ok=True)
-        df.to_parquet(new_path)
+        final_df = df
     else:
         # Append to existing file
         existing_df = pd.read_parquet(dst_path)
-        combined_df = pd.concat([existing_df, df], ignore_index=True)
-        combined_df.to_parquet(dst_path)
+        final_df = pd.concat([existing_df, df], ignore_index=True)
+
+    if contains_images:
+        to_parquet_with_hf_images(final_df, new_path)
+    else:
+        final_df.to_parquet(new_path)
 
     return idx
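
Taken together, these hunks move the destination-path bookkeeping into the helper itself: callers now pass only the (chunk, file) index dict, the size cap, the files-per-chunk limit, and the path template, and append_or_create_parquet_file decides whether to append rows to the current Parquet file or roll over to a new one. Below is a rough standalone sketch of that append-or-rollover pattern, for orientation only; the size estimate and the inline index arithmetic are simplified stand-ins for get_parquet_file_size_in_mb and update_chunk_file_indices from the imports above, and the function name is made up.

# Hedged sketch of the append-or-rollover pattern behind append_or_create_parquet_file;
# names and size heuristics are illustrative, not the LeRobot implementation.
from pathlib import Path

import pandas as pd


def append_or_rollover(
    df: pd.DataFrame,
    root: Path,
    idx: dict[str, int],  # {"chunk": int, "file": int}
    max_mb: float,
    files_per_chunk: int,
    path_template: str,  # e.g. "data/chunk-{chunk_index:03d}/file-{file_index:03d}.parquet"
) -> dict[str, int]:
    dst_path = root / path_template.format(chunk_index=idx["chunk"], file_index=idx["file"])
    dst_path.parent.mkdir(parents=True, exist_ok=True)

    if not dst_path.exists():
        # First write for this (chunk, file) pair.
        df.to_parquet(dst_path)
        return idx

    dst_mb = dst_path.stat().st_size / (1024 * 1024)
    new_mb = df.memory_usage(deep=True).sum() / (1024 * 1024)

    if dst_mb + new_mb >= max_mb:
        # Size cap reached: advance the indices and start a fresh file.
        idx["file"] += 1
        if idx["file"] >= files_per_chunk:
            idx["chunk"], idx["file"] = idx["chunk"] + 1, 0
        new_path = root / path_template.format(chunk_index=idx["chunk"], file_index=idx["file"])
        new_path.parent.mkdir(parents=True, exist_ok=True)
        df.to_parquet(new_path)
    else:
        # Still room: append the new rows to the existing file.
        combined = pd.concat([pd.read_parquet(dst_path), df], ignore_index=True)
        combined.to_parquet(dst_path)

    return idx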

View File

@@ -63,7 +63,7 @@ from lerobot.common.datasets.utils import (
     load_nested_dataset,
     load_stats,
     load_tasks,
-    safe_write_dataframe_to_parquet,
+    to_parquet_with_hf_images,
     update_chunk_file_indices,
     validate_episode_buffer,
     validate_frame,
@@ -1009,7 +1009,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
         # Write the resulting dataframe from RAM to disk
         path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
         path.parent.mkdir(parents=True, exist_ok=True)
-        safe_write_dataframe_to_parquet(df, path, self.meta.image_keys)
+        if len(self.meta.image_keys) > 0:
+            to_parquet_with_hf_images(df, path)
+        else:
+            df.to_parquet(path)
 
         # Update the Hugging Face dataset by reloading it.
         # This process should be fast because only the latest Parquet file has been modified.
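
The inlined branch above is the write dispatch used when saving an episode: frames with image columns go through the Hugging Face datasets writer so the image feature metadata ends up in the Parquet schema, everything else takes the plain pandas path. A hypothetical standalone version of the same dispatch, just to show it in isolation (write_data_file is a made-up name; the import matches the one added in this file):

# Hypothetical standalone version of the inlined dispatch; not part of the commit.
from pathlib import Path

import pandas as pd

from lerobot.common.datasets.utils import to_parquet_with_hf_images


def write_data_file(df: pd.DataFrame, path: Path, image_keys: list[str]) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    if len(image_keys) > 0:
        # Image columns must be written through datasets.Dataset so the
        # HF Image() feature metadata is embedded in the Parquet schema.
        to_parquet_with_hf_images(df, path)
    else:
        df.to_parquet(path)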

View File

@@ -892,9 +892,9 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
        )
 
 
-def safe_write_dataframe_to_parquet(df: pandas.DataFrame, path: Path, image_keys: list[str]):
-    if len(image_keys) > 0:
-        # TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
-        datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
-    else:
-        df.to_parquet(path)
+def to_parquet_with_hf_images(df: pandas.DataFrame, path: Path):
+    """Write to Parquet a pandas DataFrame whose image columns are encoded by HF datasets.
+    This way, it can be loaded back with HF datasets and correctly formatted images are returned.
+    """
+    # TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
+    datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
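
For context, here is a small round-trip sketch of what the datasets.Dataset detour buys when a column holds images. The data and column names are made up for illustration; the assumption is that datasets infers an Image() feature from PIL objects and preserves it through to_parquet / Dataset.from_parquet, which is the behavior the helper relies on.

# Illustrative round trip, not part of the commit.
from pathlib import Path

import datasets
import numpy as np
import pandas as pd
from PIL import Image

df = pd.DataFrame(
    {
        "frame_index": [0, 1],
        "observation.image": [
            Image.fromarray(np.zeros((8, 8, 3), dtype=np.uint8)),
            Image.fromarray(np.full((8, 8, 3), 255, dtype=np.uint8)),
        ],
    }
)

path = Path("example.parquet")
# Same write path as to_parquet_with_hf_images: the Dataset writer encodes the
# image column as an HF Image() feature, whereas df.to_parquet(path) would try
# to serialize raw PIL objects.
datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)

reloaded = datasets.Dataset.from_parquet(str(path))
print(reloaded.features["observation.image"])   # Image(...)
print(type(reloaded[0]["observation.image"]))   # decoded back to a PIL image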