forked from tangger/lerobot
WIP after Francesco discussion
This commit is contained in:
@@ -18,7 +18,7 @@ from lerobot.common.datasets.utils import (
|
|||||||
concat_video_files,
|
concat_video_files,
|
||||||
get_parquet_file_size_in_mb,
|
get_parquet_file_size_in_mb,
|
||||||
get_video_size_in_mb,
|
get_video_size_in_mb,
|
||||||
safe_write_dataframe_to_parquet,
|
to_parquet_with_hf_images,
|
||||||
update_chunk_file_indices,
|
update_chunk_file_indices,
|
||||||
write_info,
|
write_info,
|
||||||
write_stats,
|
write_stats,
|
||||||
@@ -125,11 +125,14 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
|
|||||||
key: {"chunk": 0, "file": 0, "latest_duration": 0, "episode_duration": 0} for key in video_keys
|
key: {"chunk": 0, "file": 0, "latest_duration": 0, "episode_duration": 0} for key in video_keys
|
||||||
}
|
}
|
||||||
|
|
||||||
|
dst_meta.episodes = {}
|
||||||
|
|
||||||
# Process each dataset
|
# Process each dataset
|
||||||
for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
|
for src_meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
|
||||||
videos_idx = aggregate_videos(src_meta, dst_meta, videos_idx)
|
videos_idx = aggregate_videos(src_meta, dst_meta, videos_idx)
|
||||||
data_idx = aggregate_data(src_meta, dst_meta, data_idx)
|
data_idx = aggregate_data(src_meta, dst_meta, data_idx)
|
||||||
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx, video_keys)
|
|
||||||
|
meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx, video_keys, image_keys)
|
||||||
|
|
||||||
dst_meta.info["total_episodes"] += src_meta.total_episodes
|
dst_meta.info["total_episodes"] += src_meta.total_episodes
|
||||||
dst_meta.info["total_frames"] += src_meta.total_frames
|
dst_meta.info["total_frames"] += src_meta.total_frames
|
||||||
@@ -205,7 +208,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx):
|
|||||||
file_idx,
|
file_idx,
|
||||||
)
|
)
|
||||||
|
|
||||||
if aggr_size_in_mb + size_in_mb >= DEFAULT_DATA_FILE_SIZE_IN_MB:
|
if src_size + dst_size >= DEFAULT_DATA_FILE_SIZE_IN_MB:
|
||||||
# Size limit is reached, prepare new parquet file
|
# Size limit is reached, prepare new parquet file
|
||||||
aggr_data_chunk_idx, aggr_data_file_idx = update_chunk_file_indices(
|
aggr_data_chunk_idx, aggr_data_file_idx = update_chunk_file_indices(
|
||||||
aggr_data_chunk_idx, aggr_data_file_idx, DEFAULT_CHUNK_SIZE
|
aggr_data_chunk_idx, aggr_data_file_idx, DEFAULT_CHUNK_SIZE
|
||||||
@@ -219,7 +222,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx):
|
|||||||
# Update the existing parquet file with new rows
|
# Update the existing parquet file with new rows
|
||||||
aggr_df = pd.read_parquet(aggr_path)
|
aggr_df = pd.read_parquet(aggr_path)
|
||||||
df = pd.concat([aggr_df, df], ignore_index=True)
|
df = pd.concat([aggr_df, df], ignore_index=True)
|
||||||
safe_write_dataframe_to_parquet(df, aggr_path, image_keys)
|
to_parquet_with_hf_images(df, aggr_path, dst_meta.image_keys)
|
||||||
|
|
||||||
return videos_idx
|
return videos_idx
|
||||||
|
|
||||||
@@ -238,17 +241,14 @@ def aggregate_data(src_meta, dst_meta, data_idx):
|
|||||||
df = pd.read_parquet(src_path)
|
df = pd.read_parquet(src_path)
|
||||||
df = update_data_df(df, src_meta, dst_meta)
|
df = update_data_df(df, src_meta, dst_meta)
|
||||||
|
|
||||||
dst_path = aggr_root / DEFAULT_DATA_PATH.format(
|
data_idx = append_or_create_parquet_file(
|
||||||
chunk_index=data_idx["chunk"], file_index=data_idx["file"]
|
|
||||||
)
|
|
||||||
data_idx = write_parquet_safely(
|
|
||||||
df,
|
df,
|
||||||
src_path,
|
src_path,
|
||||||
dst_path,
|
|
||||||
data_idx,
|
data_idx,
|
||||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||||
DEFAULT_CHUNK_SIZE,
|
DEFAULT_CHUNK_SIZE,
|
||||||
DEFAULT_DATA_PATH,
|
DEFAULT_DATA_PATH,
|
||||||
|
contains_images=len(dst_meta.image_keys) > 0
|
||||||
)
|
)
|
||||||
|
|
||||||
return data_idx
|
return data_idx
|
||||||
@@ -278,13 +278,9 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
|||||||
# for k in video_keys:
|
# for k in video_keys:
|
||||||
# video_idx[k]["latest_duration"] += video_idx[k]["episode_duration"]
|
# video_idx[k]["latest_duration"] += video_idx[k]["episode_duration"]
|
||||||
|
|
||||||
dst_path = dst_meta.root / DEFAULT_EPISODES_PATH.format(
|
append_or_create_parquet_file(
|
||||||
chunk_index=meta_idx["chunk"], file_index=meta_idx["file"]
|
|
||||||
)
|
|
||||||
write_parquet_safely(
|
|
||||||
df,
|
df,
|
||||||
src_path,
|
src_path,
|
||||||
dst_path,
|
|
||||||
meta_idx,
|
meta_idx,
|
||||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||||
DEFAULT_CHUNK_SIZE,
|
DEFAULT_CHUNK_SIZE,
|
||||||
@@ -294,14 +290,14 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
|
|||||||
return meta_idx
|
return meta_idx
|
||||||
|
|
||||||
|
|
||||||
def write_parquet_safely(
|
def append_or_create_parquet_file(
|
||||||
df: pd.DataFrame,
|
df: pd.DataFrame,
|
||||||
src_path: Path,
|
src_path: Path,
|
||||||
dst_path: Path,
|
|
||||||
idx: dict[str, int],
|
idx: dict[str, int],
|
||||||
max_mb: float,
|
max_mb: float,
|
||||||
chunk_size: int,
|
chunk_size: int,
|
||||||
default_path: str,
|
default_path: str,
|
||||||
|
contains_images: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Safely appends or creates a Parquet file at dst_path based on size constraints.
|
Safely appends or creates a Parquet file at dst_path based on size constraints.
|
||||||
@@ -309,7 +305,6 @@ def write_parquet_safely(
|
|||||||
Parameters:
|
Parameters:
|
||||||
df (pd.DataFrame): Data to write.
|
df (pd.DataFrame): Data to write.
|
||||||
src_path (Path): Path to source file (used to get size).
|
src_path (Path): Path to source file (used to get size).
|
||||||
dst_path (Path): Target path for writing.
|
|
||||||
idx (dict): Dictionary containing 'chunk' and 'file' indices.
|
idx (dict): Dictionary containing 'chunk' and 'file' indices.
|
||||||
max_mb (float): Maximum allowed file size in MB.
|
max_mb (float): Maximum allowed file size in MB.
|
||||||
chunk_size (int): Maximum number of files per chunk.
|
chunk_size (int): Maximum number of files per chunk.
|
||||||
@@ -318,6 +313,10 @@ def write_parquet_safely(
|
|||||||
Returns:
|
Returns:
|
||||||
dict: Updated index dictionary.
|
dict: Updated index dictionary.
|
||||||
"""
|
"""
|
||||||
|
# Initial destination path
|
||||||
|
dst_path = aggr_root / DEFAULT_DATA_PATH.format(
|
||||||
|
chunk_index=idx["chunk"], file_index=idx["file"]
|
||||||
|
)
|
||||||
|
|
||||||
# If destination file doesn't exist, just write the new one
|
# If destination file doesn't exist, just write the new one
|
||||||
if not dst_path.exists():
|
if not dst_path.exists():
|
||||||
@@ -334,12 +333,16 @@ def write_parquet_safely(
|
|||||||
idx["chunk"], idx["file"] = update_chunk_file_indices(idx["chunk"], idx["file"], chunk_size)
|
idx["chunk"], idx["file"] = update_chunk_file_indices(idx["chunk"], idx["file"], chunk_size)
|
||||||
new_path = dst_path.parent / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"])
|
new_path = dst_path.parent / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"])
|
||||||
new_path.parent.mkdir(parents=True, exist_ok=True)
|
new_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
df.to_parquet(new_path)
|
final_df = df
|
||||||
else:
|
else:
|
||||||
# Append to existing file
|
# Append to existing file
|
||||||
existing_df = pd.read_parquet(dst_path)
|
existing_df = pd.read_parquet(dst_path)
|
||||||
combined_df = pd.concat([existing_df, df], ignore_index=True)
|
final_df = pd.concat([existing_df, df], ignore_index=True)
|
||||||
combined_df.to_parquet(dst_path)
|
|
||||||
|
if contains_images:
|
||||||
|
to_parquet_with_hf_images(final_df, new_path)
|
||||||
|
else:
|
||||||
|
final_df.to_parquet(new_path)
|
||||||
|
|
||||||
return idx
|
return idx
|
||||||
|
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ from lerobot.common.datasets.utils import (
|
|||||||
load_nested_dataset,
|
load_nested_dataset,
|
||||||
load_stats,
|
load_stats,
|
||||||
load_tasks,
|
load_tasks,
|
||||||
safe_write_dataframe_to_parquet,
|
to_parquet_with_hf_images,
|
||||||
update_chunk_file_indices,
|
update_chunk_file_indices,
|
||||||
validate_episode_buffer,
|
validate_episode_buffer,
|
||||||
validate_frame,
|
validate_frame,
|
||||||
@@ -1009,7 +1009,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
# Write the resulting dataframe from RAM to disk
|
# Write the resulting dataframe from RAM to disk
|
||||||
path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
|
path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||||
path.parent.mkdir(parents=True, exist_ok=True)
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
safe_write_dataframe_to_parquet(df, path, self.meta.image_keys)
|
if len(self.meta.image_keys) > 0:
|
||||||
|
to_parquet_with_hf_images(df, path)
|
||||||
|
else:
|
||||||
|
df.to_parquet(path)
|
||||||
|
|
||||||
# Update the Hugging Face dataset by reloading it.
|
# Update the Hugging Face dataset by reloading it.
|
||||||
# This process should be fast because only the latest Parquet file has been modified.
|
# This process should be fast because only the latest Parquet file has been modified.
|
||||||
|
|||||||
@@ -892,9 +892,9 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def safe_write_dataframe_to_parquet(df: pandas.DataFrame, path: Path, image_keys: list[str]):
|
def to_parquet_with_hf_images(df: pandas.DataFrame, path: Path):
|
||||||
if len(image_keys) > 0:
|
"""This function correctly writes to parquet a pandas DataFrame that contains images encoded by HF dataset.
|
||||||
# TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
|
This way, it can be loaded by HF dataset and correctly formatted images are returned.
|
||||||
datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
|
"""
|
||||||
else:
|
# TODO(qlhoest): replace this weird syntax by `df.to_parquet(path)` only
|
||||||
df.to_parquet(path)
|
datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)
|
||||||
|
|||||||
Reference in New Issue
Block a user