From 960589849fe9d365e1a4c9f8d57838c5a217fd8e Mon Sep 17 00:00:00 2001 From: Remi Cadene Date: Tue, 28 May 2024 20:37:12 +0000 Subject: [PATCH] Fix aloha_dora_format --- .../push_dataset_to_hub/aloha_dora_format.py | 28 +++++++++++++------ 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/lerobot/common/datasets/push_dataset_to_hub/aloha_dora_format.py b/lerobot/common/datasets/push_dataset_to_hub/aloha_dora_format.py index 20c3e798e..b897bd138 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/aloha_dora_format.py +++ b/lerobot/common/datasets/push_dataset_to_hub/aloha_dora_format.py @@ -18,6 +18,7 @@ Contains utilities to process raw data format from dora-record """ import logging +import re from pathlib import Path import pandas as pd @@ -50,9 +51,7 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int): df = reference_df for path in raw_dir.glob("*.parquet"): key = path.stem # action or observation.state or ... - if key == reference_key: - continue - elif "failed_episode_index" in key: + if key == reference_key or "failed_episode_index" in key: continue modality_df = pd.read_parquet(path) modality_df = modality_df[["timestamp_utc", key]] @@ -64,15 +63,28 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int): tolerance=pd.Timedelta(f"{1/fps} seconds"), ) - # Remove rows with a NaN in any column. It can happened during the first frames of an episode, - # because some cameras didnt start recording yet. - df = df.dropna(axis=0) - # Remove rows with episode_index -1 which indicates data that correspond to in-between episodes df = df[df["episode_index"] != -1] + image_keys = [key for key in df if "observation.images." in key] + + def get_episode_index(row): + episode_index_per_cam = {} + for key in image_keys: + path = row[key][0]["path"] + match = re.search(r"_(\d{6}).mp4", path) + if not match: + raise ValueError(path) + episode_index = int(match.group(1)) + episode_index_per_cam[key] = episode_index + assert ( + len(set(episode_index_per_cam.values())) == 1 + ), f"All cameras are expected to belong to the same episode, but getting {episode_index_per_cam}" + return episode_index + + df["episode_index"] = df.apply(get_episode_index, axis=1) + # dora only use arrays, so single values are encapsulated into a list - df["episode_index"] = df["episode_index"].map(lambda x: x[0]) df["frame_index"] = df.groupby("episode_index").cumcount() df = df.reset_index() df["index"] = df.index