Merge remote-tracking branch 'origin/user/rcadene/2025_04_11_dataset_v3' into user/rcadene/2025_04_11_dataset_v3

This commit is contained in:
Remi Cadene
2025-04-23 09:16:37 +00:00
5 changed files with 95 additions and 75 deletions

View File

@@ -187,6 +187,7 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
for chunk_idx, file_idx in data_chunk_file_ids:
path = meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
df = pd.read_parquet(path)
# TODO(rcadene): update frame index
update_data_func = get_update_episode_and_task_func(num_episodes, meta.tasks, aggr_meta.tasks)
df = df.apply(update_data_func, axis=1)

View File

@@ -198,16 +198,15 @@ def convert_data(root, new_root):
def get_video_keys(root):
    """Return the names of the features whose dtype is "video".

    The feature table is read from the "features" entry of `load_info(root)`.

    Raises:
        NotImplementedError: if at least one feature has dtype "image".
    """
    feature_specs = load_info(root)["features"]

    # Collect image features first: their presence is rejected outright.
    image_names = [name for name, spec in feature_specs.items() if spec["dtype"] == "image"]
    if image_names:
        raise NotImplementedError()

    return [name for name, spec in feature_specs.items() if spec["dtype"] == "video"]
def convert_videos(root: Path, new_root: Path):
video_keys = get_video_keys(root)
if len(video_keys) == 0:
return None
video_keys = sorted(video_keys)
eps_metadata_per_cam = []
@@ -285,24 +284,32 @@ def convert_videos_of_camera(root: Path, new_root: Path, video_key):
def generate_episode_metadata_dict(
episodes_legacy_metadata, episodes_metadata, episodes_videos, episodes_stats
episodes_legacy_metadata, episodes_metadata, episodes_stats, episodes_videos=None
):
for ep_legacy_metadata, ep_metadata, ep_video, ep_stats, ep_idx_stats in zip(
episodes_legacy_metadata.values(),
episodes_metadata,
episodes_videos,
episodes_stats.values(),
episodes_stats.keys(),
strict=False,
):
ep_idx = ep_legacy_metadata["episode_index"]
ep_idx_data = ep_metadata["episode_index"]
ep_idx_video = ep_video["episode_index"]
num_episodes = len(episodes_metadata)
episodes_legacy_metadata_vals = list(episodes_legacy_metadata.values())
episodes_stats_vals = list(episodes_stats.values())
episodes_stats_keys = list(episodes_stats.keys())
if len({ep_idx, ep_idx_data, ep_idx_video, ep_idx_stats}) != 1:
raise ValueError(
f"Number of episodes is not the same ({ep_idx=},{ep_idx_data=},{ep_idx_video=},{ep_idx_stats=})."
)
for i in range(num_episodes):
ep_legacy_metadata = episodes_legacy_metadata_vals[i]
ep_metadata = episodes_metadata[i]
ep_stats = episodes_stats_vals[i]
ep_ids_set = {
ep_legacy_metadata["episode_index"],
ep_metadata["episode_index"],
episodes_stats_keys[i],
}
if episodes_videos is None:
ep_video = {}
else:
ep_video = episodes_videos[i]
ep_ids_set.add(ep_video["episode_index"])
if len(ep_ids_set) != 1:
raise ValueError(f"Number of episodes is not the same ({ep_ids_set}).")
ep_dict = {**ep_metadata, **ep_video, **ep_legacy_metadata, **flatten_dict({"stats": ep_stats})}
ep_dict["meta/episodes/chunk_index"] = 0
@@ -310,21 +317,20 @@ def generate_episode_metadata_dict(
yield ep_dict
def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_metadata):
def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_metadata=None):
episodes_legacy_metadata = legacy_load_episodes(root)
episodes_stats = legacy_load_episodes_stats(root)
num_eps = len(episodes_legacy_metadata)
num_eps_metadata = len(episodes_metadata)
num_eps_video_metadata = len(episodes_video_metadata)
if len({num_eps, num_eps_metadata, num_eps_video_metadata}) != 1:
raise ValueError(
f"Number of episodes is not the same ({num_eps=},{num_eps_metadata=},{num_eps_video_metadata=})."
)
num_eps_set = {len(episodes_legacy_metadata), len(episodes_metadata)}
if episodes_video_metadata is not None:
num_eps_set.add(len(episodes_video_metadata))
if len(num_eps_set) != 1:
raise ValueError(f"Number of episodes is not the same ({num_eps_set}).")
ds_episodes = Dataset.from_generator(
lambda: generate_episode_metadata_dict(
episodes_legacy_metadata, episodes_metadata, episodes_video_metadata, episodes_stats
episodes_legacy_metadata, episodes_metadata, episodes_stats, episodes_video_metadata
)
)
write_episodes(ds_episodes, new_root)