Improve V3 aggregate implementation (#2077)

* fix return type

* improve apply with vectorized op

* Update src/lerobot/datasets/aggregate.py

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
This commit is contained in:
Qizhi Chen
2025-09-29 17:18:54 +08:00
committed by GitHub
parent f59eb54f5c
commit 90684a9690
2 changed files with 22 additions and 29 deletions

View File

@@ -93,14 +93,13 @@ def update_data_df(df, src_meta, dst_meta):
pd.DataFrame: Updated DataFrame with adjusted indices. pd.DataFrame: Updated DataFrame with adjusted indices.
""" """
def _update(row): df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"]
row["episode_index"] = row["episode_index"] + dst_meta.info["total_episodes"] df["index"] = df["index"] + dst_meta.info["total_frames"]
row["index"] = row["index"] + dst_meta.info["total_frames"]
task = src_meta.tasks.iloc[row["task_index"]].name
row["task_index"] = dst_meta.tasks.loc[task].task_index.item()
return row
return df.apply(_update, axis=1) src_task_names = src_meta.tasks.index.take(df["task_index"].to_numpy())
df["task_index"] = dst_meta.tasks.loc[src_task_names, "task_index"].to_numpy()
return df
def update_meta_data( def update_meta_data(
@@ -126,27 +125,21 @@ def update_meta_data(
pd.DataFrame: Updated DataFrame with adjusted indices and timestamps. pd.DataFrame: Updated DataFrame with adjusted indices and timestamps.
""" """
def _update(row): df["meta/episodes/chunk_index"] = df["meta/episodes/chunk_index"] + meta_idx["chunk"]
row["meta/episodes/chunk_index"] = row["meta/episodes/chunk_index"] + meta_idx["chunk"] df["meta/episodes/file_index"] = df["meta/episodes/file_index"] + meta_idx["file"]
row["meta/episodes/file_index"] = row["meta/episodes/file_index"] + meta_idx["file"] df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
row["data/chunk_index"] = row["data/chunk_index"] + data_idx["chunk"] df["data/file_index"] = df["data/file_index"] + data_idx["file"]
row["data/file_index"] = row["data/file_index"] + data_idx["file"] for key, video_idx in videos_idx.items():
for key, video_idx in videos_idx.items(): df[f"videos/{key}/chunk_index"] = df[f"videos/{key}/chunk_index"] + video_idx["chunk"]
row[f"videos/{key}/chunk_index"] = row[f"videos/{key}/chunk_index"] + video_idx["chunk"] df[f"videos/{key}/file_index"] = df[f"videos/{key}/file_index"] + video_idx["file"]
row[f"videos/{key}/file_index"] = row[f"videos/{key}/file_index"] + video_idx["file"] df[f"videos/{key}/from_timestamp"] = df[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"]
row[f"videos/{key}/from_timestamp"] = ( df[f"videos/{key}/to_timestamp"] = df[f"videos/{key}/to_timestamp"] + video_idx["latest_duration"]
row[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"]
)
row[f"videos/{key}/to_timestamp"] = (
row[f"videos/{key}/to_timestamp"] + video_idx["latest_duration"]
)
row["dataset_from_index"] = row["dataset_from_index"] + dst_meta.info["total_frames"] df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info["total_frames"]
row["dataset_to_index"] = row["dataset_to_index"] + dst_meta.info["total_frames"] df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info["total_frames"]
row["episode_index"] = row["episode_index"] + dst_meta.info["total_episodes"] df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"]
return row
return df.apply(_update, axis=1) return df
def aggregate_datasets( def aggregate_datasets(

View File

@@ -1027,7 +1027,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Reset episode buffer and clean up temporary images (if not already deleted during video encoding) # Reset episode buffer and clean up temporary images (if not already deleted during video encoding)
self.clear_episode_buffer(delete_images=len(self.meta.image_keys) > 0) self.clear_episode_buffer(delete_images=len(self.meta.image_keys) > 0)
def _batch_save_episode_video(self, start_episode: int, end_episode: int | None = None): def _batch_save_episode_video(self, start_episode: int, end_episode: int | None = None) -> None:
""" """
Batch save videos for multiple episodes. Batch save videos for multiple episodes.
@@ -1153,7 +1153,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
} }
return metadata return metadata
def _save_episode_video(self, video_key: str, episode_index: int): def _save_episode_video(self, video_key: str, episode_index: int) -> dict:
# Encode episode frames into a temporary video # Encode episode frames into a temporary video
ep_path = self._encode_temporary_episode_video(video_key, episode_index) ep_path = self._encode_temporary_episode_video(video_key, episode_index)
ep_size_in_mb = get_video_size_in_mb(ep_path) ep_size_in_mb = get_video_size_in_mb(ep_path)
@@ -1258,7 +1258,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
if self.image_writer is not None: if self.image_writer is not None:
self.image_writer.wait_until_done() self.image_writer.wait_until_done()
def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> dict: def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path:
""" """
Use ffmpeg to convert frames stored as png into mp4 videos. Use ffmpeg to convert frames stored as png into mp4 videos.
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding, Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,