forked from tangger/lerobot
Merge remote-tracking branch 'origin/user/rcadene/2025_04_11_dataset_v3' into user/rcadene/2025_04_11_dataset_v3
This commit is contained in:
@@ -187,6 +187,7 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
|
||||
for chunk_idx, file_idx in data_chunk_file_ids:
|
||||
path = meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
df = pd.read_parquet(path)
|
||||
# TODO(rcadene): update frame index
|
||||
update_data_func = get_update_episode_and_task_func(num_episodes, meta.tasks, aggr_meta.tasks)
|
||||
df = df.apply(update_data_func, axis=1)
|
||||
|
||||
|
||||
@@ -198,16 +198,15 @@ def convert_data(root, new_root):
|
||||
def get_video_keys(root):
|
||||
info = load_info(root)
|
||||
features = info["features"]
|
||||
image_keys = [key for key, ft in features.items() if ft["dtype"] == "image"]
|
||||
if len(image_keys) != 0:
|
||||
raise NotImplementedError()
|
||||
|
||||
video_keys = [key for key, ft in features.items() if ft["dtype"] == "video"]
|
||||
return video_keys
|
||||
|
||||
|
||||
def convert_videos(root: Path, new_root: Path):
|
||||
video_keys = get_video_keys(root)
|
||||
if len(video_keys) == 0:
|
||||
return None
|
||||
|
||||
video_keys = sorted(video_keys)
|
||||
|
||||
eps_metadata_per_cam = []
|
||||
@@ -285,24 +284,32 @@ def convert_videos_of_camera(root: Path, new_root: Path, video_key):
|
||||
|
||||
|
||||
def generate_episode_metadata_dict(
|
||||
episodes_legacy_metadata, episodes_metadata, episodes_videos, episodes_stats
|
||||
episodes_legacy_metadata, episodes_metadata, episodes_stats, episodes_videos=None
|
||||
):
|
||||
for ep_legacy_metadata, ep_metadata, ep_video, ep_stats, ep_idx_stats in zip(
|
||||
episodes_legacy_metadata.values(),
|
||||
episodes_metadata,
|
||||
episodes_videos,
|
||||
episodes_stats.values(),
|
||||
episodes_stats.keys(),
|
||||
strict=False,
|
||||
):
|
||||
ep_idx = ep_legacy_metadata["episode_index"]
|
||||
ep_idx_data = ep_metadata["episode_index"]
|
||||
ep_idx_video = ep_video["episode_index"]
|
||||
num_episodes = len(episodes_metadata)
|
||||
episodes_legacy_metadata_vals = list(episodes_legacy_metadata.values())
|
||||
episodes_stats_vals = list(episodes_stats.values())
|
||||
episodes_stats_keys = list(episodes_stats.keys())
|
||||
|
||||
if len({ep_idx, ep_idx_data, ep_idx_video, ep_idx_stats}) != 1:
|
||||
raise ValueError(
|
||||
f"Number of episodes is not the same ({ep_idx=},{ep_idx_data=},{ep_idx_video=},{ep_idx_stats=})."
|
||||
)
|
||||
for i in range(num_episodes):
|
||||
ep_legacy_metadata = episodes_legacy_metadata_vals[i]
|
||||
ep_metadata = episodes_metadata[i]
|
||||
ep_stats = episodes_stats_vals[i]
|
||||
|
||||
ep_ids_set = {
|
||||
ep_legacy_metadata["episode_index"],
|
||||
ep_metadata["episode_index"],
|
||||
episodes_stats_keys[i],
|
||||
}
|
||||
|
||||
if episodes_videos is None:
|
||||
ep_video = {}
|
||||
else:
|
||||
ep_video = episodes_videos[i]
|
||||
ep_ids_set.add(ep_video["episode_index"])
|
||||
|
||||
if len(ep_ids_set) != 1:
|
||||
raise ValueError(f"Number of episodes is not the same ({ep_ids_set}).")
|
||||
|
||||
ep_dict = {**ep_metadata, **ep_video, **ep_legacy_metadata, **flatten_dict({"stats": ep_stats})}
|
||||
ep_dict["meta/episodes/chunk_index"] = 0
|
||||
@@ -310,21 +317,20 @@ def generate_episode_metadata_dict(
|
||||
yield ep_dict
|
||||
|
||||
|
||||
def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_metadata):
|
||||
def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_metadata=None):
|
||||
episodes_legacy_metadata = legacy_load_episodes(root)
|
||||
episodes_stats = legacy_load_episodes_stats(root)
|
||||
|
||||
num_eps = len(episodes_legacy_metadata)
|
||||
num_eps_metadata = len(episodes_metadata)
|
||||
num_eps_video_metadata = len(episodes_video_metadata)
|
||||
if len({num_eps, num_eps_metadata, num_eps_video_metadata}) != 1:
|
||||
raise ValueError(
|
||||
f"Number of episodes is not the same ({num_eps=},{num_eps_metadata=},{num_eps_video_metadata=})."
|
||||
)
|
||||
num_eps_set = {len(episodes_legacy_metadata), len(episodes_metadata)}
|
||||
if episodes_video_metadata is not None:
|
||||
num_eps_set.add(len(episodes_video_metadata))
|
||||
|
||||
if len(num_eps_set) != 1:
|
||||
raise ValueError(f"Number of episodes is not the same ({num_eps_set}).")
|
||||
|
||||
ds_episodes = Dataset.from_generator(
|
||||
lambda: generate_episode_metadata_dict(
|
||||
episodes_legacy_metadata, episodes_metadata, episodes_video_metadata, episodes_stats
|
||||
episodes_legacy_metadata, episodes_metadata, episodes_stats, episodes_video_metadata
|
||||
)
|
||||
)
|
||||
write_episodes(ds_episodes, new_root)
|
||||
|
||||
Reference in New Issue
Block a user