forked from tangger/lerobot
Aggregate: Add concatenation
This commit is contained in:
@@ -9,9 +9,15 @@ from lerobot.common.datasets.compute_stats import aggregate_stats
|
|||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||||
from lerobot.common.datasets.utils import (
|
from lerobot.common.datasets.utils import (
|
||||||
DEFAULT_CHUNK_SIZE,
|
DEFAULT_CHUNK_SIZE,
|
||||||
|
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||||
DEFAULT_DATA_PATH,
|
DEFAULT_DATA_PATH,
|
||||||
DEFAULT_EPISODES_PATH,
|
DEFAULT_EPISODES_PATH,
|
||||||
|
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||||
DEFAULT_VIDEO_PATH,
|
DEFAULT_VIDEO_PATH,
|
||||||
|
concat_video_files,
|
||||||
|
get_parquet_file_size_in_mb,
|
||||||
|
get_video_size_in_mb,
|
||||||
|
update_chunk_file_indices,
|
||||||
write_info,
|
write_info,
|
||||||
write_stats,
|
write_stats,
|
||||||
write_tasks,
|
write_tasks,
|
||||||
@@ -41,17 +47,18 @@ def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
|
|||||||
return fps, robot_type, features
|
return fps, robot_type, features
|
||||||
|
|
||||||
|
|
||||||
def get_update_episode_and_task_func(episode_index_to_add, old_tasks, new_tasks):
|
def update_episode_and_task(df, episode_index_to_add, old_tasks, new_tasks):
|
||||||
def _update(row):
|
def _update(row):
|
||||||
row["episode_index"] = row["episode_index"] + episode_index_to_add
|
row["episode_index"] = row["episode_index"] + episode_index_to_add
|
||||||
task = old_tasks.iloc[row["task_index"]].name
|
task = old_tasks.iloc[row["task_index"]].name
|
||||||
row["task_index"] = new_tasks.loc[task].task_index.item()
|
row["task_index"] = new_tasks.loc[task].task_index.item()
|
||||||
return row
|
return row
|
||||||
|
|
||||||
return _update
|
return df.apply(_update, axis=1)
|
||||||
|
|
||||||
|
|
||||||
def get_update_meta_func(
|
def update_meta_data(
|
||||||
|
df,
|
||||||
meta_chunk_index_to_add,
|
meta_chunk_index_to_add,
|
||||||
meta_file_index_to_add,
|
meta_file_index_to_add,
|
||||||
data_chunk_index_to_add,
|
data_chunk_index_to_add,
|
||||||
@@ -74,7 +81,7 @@ def get_update_meta_func(
|
|||||||
row["dataset_to_index"] = row["dataset_to_index"] + frame_index_to_add
|
row["dataset_to_index"] = row["dataset_to_index"] + frame_index_to_add
|
||||||
return row
|
return row
|
||||||
|
|
||||||
return _update
|
return df.apply(_update, axis=1)
|
||||||
|
|
||||||
|
|
||||||
def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path] = None, aggr_root=None):
|
def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path] = None, aggr_root=None):
|
||||||
@@ -117,6 +124,7 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
|
|||||||
aggr_videos_file_idx = dict.fromkeys(video_keys, 0)
|
aggr_videos_file_idx = dict.fromkeys(video_keys, 0)
|
||||||
|
|
||||||
for meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
|
for meta in tqdm.tqdm(all_metadata, desc="Copy data and videos"):
|
||||||
|
# Aggregate episodes meta data
|
||||||
meta_chunk_file_ids = {
|
meta_chunk_file_ids = {
|
||||||
(c, f)
|
(c, f)
|
||||||
for c, f in zip(
|
for c, f in zip(
|
||||||
@@ -128,7 +136,8 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
|
|||||||
for chunk_idx, file_idx in meta_chunk_file_ids:
|
for chunk_idx, file_idx in meta_chunk_file_ids:
|
||||||
path = meta.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
path = meta.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||||
df = pd.read_parquet(path)
|
df = pd.read_parquet(path)
|
||||||
update_meta_func = get_update_meta_func(
|
df = update_meta_data(
|
||||||
|
df,
|
||||||
aggr_meta_chunk_idx,
|
aggr_meta_chunk_idx,
|
||||||
aggr_meta_file_idx,
|
aggr_meta_file_idx,
|
||||||
aggr_data_chunk_idx,
|
aggr_data_chunk_idx,
|
||||||
@@ -137,20 +146,29 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
|
|||||||
aggr_videos_file_idx,
|
aggr_videos_file_idx,
|
||||||
num_frames,
|
num_frames,
|
||||||
)
|
)
|
||||||
df = df.apply(update_meta_func, axis=1)
|
|
||||||
|
|
||||||
aggr_path = aggr_root / DEFAULT_EPISODES_PATH.format(
|
aggr_path = aggr_root / DEFAULT_EPISODES_PATH.format(
|
||||||
chunk_index=aggr_meta_chunk_idx, file_index=aggr_meta_file_idx
|
chunk_index=aggr_meta_chunk_idx, file_index=aggr_meta_file_idx
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if aggr_path.exists():
|
||||||
|
size_in_mb = get_parquet_file_size_in_mb(path)
|
||||||
|
aggr_size_in_mb = get_parquet_file_size_in_mb(aggr_path)
|
||||||
|
|
||||||
|
if aggr_size_in_mb + size_in_mb >= DEFAULT_DATA_FILE_SIZE_IN_MB:
|
||||||
|
# Size limit is reached, prepare new parquet file
|
||||||
|
aggr_meta_chunk_idx, aggr_meta_file_idx = update_chunk_file_indices(
|
||||||
|
aggr_meta_chunk_idx, aggr_meta_file_idx, DEFAULT_CHUNK_SIZE
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Update the existing parquet file with new rows
|
||||||
|
aggr_df = pd.read_parquet(aggr_path)
|
||||||
|
df = pd.concat([aggr_df, df], ignore_index=True)
|
||||||
|
|
||||||
aggr_path.parent.mkdir(parents=True, exist_ok=True)
|
aggr_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
df.to_parquet(aggr_path)
|
df.to_parquet(aggr_path)
|
||||||
|
|
||||||
aggr_meta_file_idx += 1
|
# Aggregate videos if any
|
||||||
if aggr_meta_file_idx >= DEFAULT_CHUNK_SIZE:
|
|
||||||
aggr_meta_file_idx = 0
|
|
||||||
aggr_meta_chunk_idx += 1
|
|
||||||
|
|
||||||
# cp videos
|
|
||||||
for key in video_keys:
|
for key in video_keys:
|
||||||
video_chunk_file_ids = {
|
video_chunk_file_ids = {
|
||||||
(c, f)
|
(c, f)
|
||||||
@@ -169,17 +187,32 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
|
|||||||
chunk_index=aggr_videos_chunk_idx[key],
|
chunk_index=aggr_videos_chunk_idx[key],
|
||||||
file_index=aggr_videos_file_idx[key],
|
file_index=aggr_videos_file_idx[key],
|
||||||
)
|
)
|
||||||
aggr_path.parent.mkdir(parents=True, exist_ok=True)
|
if aggr_path.exists():
|
||||||
shutil.copy(str(path), str(aggr_path))
|
size_in_mb = get_video_size_in_mb(path)
|
||||||
|
aggr_size_in_mb = get_video_size_in_mb(aggr_path)
|
||||||
|
|
||||||
# copy_command = f"cp {video_path} {aggr_video_path} &"
|
if aggr_size_in_mb + size_in_mb >= DEFAULT_VIDEO_FILE_SIZE_IN_MB:
|
||||||
# subprocess.Popen(copy_command, shell=True)
|
# Size limit is reached, prepare new parquet file
|
||||||
|
aggr_videos_chunk_idx[key], aggr_videos_file_idx[key] = update_chunk_file_indices(
|
||||||
|
aggr_videos_chunk_idx[key], aggr_videos_file_idx[key], DEFAULT_CHUNK_SIZE
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Update the existing parquet file with new rows
|
||||||
|
concat_video_files(
|
||||||
|
[aggr_path, path],
|
||||||
|
aggr_root,
|
||||||
|
key,
|
||||||
|
aggr_videos_chunk_idx[key],
|
||||||
|
aggr_videos_file_idx[key],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
aggr_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
shutil.copy(str(path), str(aggr_path))
|
||||||
|
|
||||||
aggr_videos_file_idx[key] += 1
|
# copy_command = f"cp {video_path} {aggr_video_path} &"
|
||||||
if aggr_videos_file_idx[key] >= DEFAULT_CHUNK_SIZE:
|
# subprocess.Popen(copy_command, shell=True)
|
||||||
aggr_videos_file_idx[key] = 0
|
|
||||||
aggr_videos_chunk_idx[key] += 1
|
|
||||||
|
|
||||||
|
# Aggregate data
|
||||||
data_chunk_file_ids = {
|
data_chunk_file_ids = {
|
||||||
(c, f)
|
(c, f)
|
||||||
for c, f in zip(meta.episodes["data/chunk_index"], meta.episodes["data/file_index"], strict=False)
|
for c, f in zip(meta.episodes["data/chunk_index"], meta.episodes["data/file_index"], strict=False)
|
||||||
@@ -188,20 +221,28 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
|
|||||||
path = meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
path = meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||||
df = pd.read_parquet(path)
|
df = pd.read_parquet(path)
|
||||||
# TODO(rcadene): update frame index
|
# TODO(rcadene): update frame index
|
||||||
update_data_func = get_update_episode_and_task_func(num_episodes, meta.tasks, aggr_meta.tasks)
|
df = update_episode_and_task(df, num_episodes, meta.tasks, aggr_meta.tasks)
|
||||||
df = df.apply(update_data_func, axis=1)
|
|
||||||
|
|
||||||
aggr_path = aggr_root / DEFAULT_DATA_PATH.format(
|
aggr_path = aggr_root / DEFAULT_DATA_PATH.format(
|
||||||
chunk_index=aggr_data_chunk_idx, file_index=aggr_data_file_idx
|
chunk_index=aggr_data_chunk_idx, file_index=aggr_data_file_idx
|
||||||
)
|
)
|
||||||
|
if aggr_path.exists():
|
||||||
|
size_in_mb = get_parquet_file_size_in_mb(path)
|
||||||
|
aggr_size_in_mb = get_parquet_file_size_in_mb(aggr_path)
|
||||||
|
|
||||||
|
if aggr_size_in_mb + size_in_mb >= DEFAULT_DATA_FILE_SIZE_IN_MB:
|
||||||
|
# Size limit is reached, prepare new parquet file
|
||||||
|
aggr_data_chunk_idx, aggr_data_file_idx = update_chunk_file_indices(
|
||||||
|
aggr_data_chunk_idx, aggr_data_file_idx, DEFAULT_CHUNK_SIZE
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Update the existing parquet file with new rows
|
||||||
|
aggr_df = pd.read_parquet(aggr_path)
|
||||||
|
df = pd.concat([aggr_df, df], ignore_index=True)
|
||||||
|
|
||||||
aggr_path.parent.mkdir(parents=True, exist_ok=True)
|
aggr_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
df.to_parquet(aggr_path)
|
df.to_parquet(aggr_path)
|
||||||
|
|
||||||
aggr_data_file_idx += 1
|
|
||||||
if aggr_data_file_idx >= DEFAULT_CHUNK_SIZE:
|
|
||||||
aggr_data_file_idx = 0
|
|
||||||
aggr_data_chunk_idx += 1
|
|
||||||
|
|
||||||
num_episodes += meta.total_episodes
|
num_episodes += meta.total_episodes
|
||||||
num_frames += meta.total_frames
|
num_frames += meta.total_frames
|
||||||
|
|
||||||
@@ -209,6 +250,7 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
|
|||||||
write_tasks(aggr_meta.tasks, aggr_meta.root)
|
write_tasks(aggr_meta.tasks, aggr_meta.root)
|
||||||
|
|
||||||
logging.info("write info")
|
logging.info("write info")
|
||||||
|
aggr_meta.info["total_tasks"] = len(aggr_meta.tasks)
|
||||||
aggr_meta.info["total_episodes"] = sum([meta.total_episodes for meta in all_metadata])
|
aggr_meta.info["total_episodes"] = sum([meta.total_episodes for meta in all_metadata])
|
||||||
aggr_meta.info["total_frames"] = sum([meta.total_frames for meta in all_metadata])
|
aggr_meta.info["total_frames"] = sum([meta.total_frames for meta in all_metadata])
|
||||||
aggr_meta.info["splits"] = {"train": f"0:{aggr_meta.total_episodes}"}
|
aggr_meta.info["splits"] = {"train": f"0:{aggr_meta.total_episodes}"}
|
||||||
@@ -221,29 +263,21 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
init_logging()
|
init_logging()
|
||||||
repo_id = "cadene/droid"
|
aggr_repo_id = "cadene/aggregate_test"
|
||||||
aggr_repo_id = "cadene/droid"
|
aggr_root = Path(f"/tmp/{aggr_repo_id}")
|
||||||
datetime = "2025-02-22_11-23-54"
|
if aggr_root.exists():
|
||||||
|
shutil.rmtree(aggr_root)
|
||||||
|
|
||||||
# root = Path(f"/tmp/{repo_id}")
|
aggregate_datasets(
|
||||||
# if root.exists():
|
["lerobot/aloha_sim_transfer_cube_human", "lerobot/aloha_sim_insertion_human"],
|
||||||
# shutil.rmtree(root)
|
aggr_repo_id,
|
||||||
root = None
|
aggr_root=aggr_root,
|
||||||
|
|
||||||
# all_metadata = [LeRobotDatasetMetadata(f"{repo_id}_{datetime}_world_2048_rank_{rank}") for rank in range(2048)]
|
|
||||||
|
|
||||||
# aggregate_datasets(
|
|
||||||
# all_metadata,
|
|
||||||
# aggr_repo_id,
|
|
||||||
# root=root,
|
|
||||||
# )
|
|
||||||
|
|
||||||
aggr_dataset = LeRobotDataset(
|
|
||||||
repo_id=aggr_repo_id,
|
|
||||||
root=root,
|
|
||||||
)
|
)
|
||||||
aggr_dataset.push_to_hub(tags=["openx"])
|
|
||||||
|
|
||||||
# for meta in all_metadata:
|
aggr_dataset = LeRobotDataset(repo_id=aggr_repo_id, root=aggr_root)
|
||||||
# dataset = LeRobotDataset(repo_id=meta.repo_id, root=meta.root)
|
|
||||||
# dataset.push_to_hub(tags=["openx"])
|
for i in tqdm.tqdm(range(len(aggr_dataset))):
|
||||||
|
aggr_dataset[i]
|
||||||
|
pass
|
||||||
|
|
||||||
|
aggr_dataset.push_to_hub(tags=["openx"])
|
||||||
|
|||||||
Reference in New Issue
Block a user