From 63a5d0be39d6eb4a7ded76f9cf07503dad5093fa Mon Sep 17 00:00:00 2001 From: Thomas Wolf Date: Tue, 4 Jun 2024 11:40:46 +0200 Subject: [PATCH] fix nans --- .../push_dataset_to_hub/aloha_dora_format.py | 35 ++++++++++++++++--- lerobot/common/datasets/utils.py | 9 ++--- lerobot/common/logger.py | 2 +- lerobot/scripts/eval.py | 5 ++- lerobot/scripts/train.py | 2 +- 5 files changed, 41 insertions(+), 12 deletions(-) diff --git a/lerobot/common/datasets/push_dataset_to_hub/aloha_dora_format.py b/lerobot/common/datasets/push_dataset_to_hub/aloha_dora_format.py index 4a21bc2d..edadd641 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/aloha_dora_format.py +++ b/lerobot/common/datasets/push_dataset_to_hub/aloha_dora_format.py @@ -78,15 +78,29 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int): image_keys = [key for key in df if "observation.images." in key] + num_unaligned_images = 0 + max_episode = 0 + def get_episode_index(row): + nonlocal num_unaligned_images + nonlocal max_episode episode_index_per_cam = {} for key in image_keys: + if isinstance(row[key], float): + num_unaligned_images += 1 + return float("nan") path = row[key][0]["path"] match = re.search(r"_(\d{6}).mp4", path) if not match: raise ValueError(path) episode_index = int(match.group(1)) episode_index_per_cam[key] = episode_index + + if episode_index > max_episode: + assert episode_index - max_episode == 1 + max_episode = episode_index + else: + assert episode_index == max_episode if len(set(episode_index_per_cam.values())) != 1: raise ValueError( f"All cameras are expected to belong to the same episode, but getting {episode_index_per_cam}" ) @@ -111,11 +125,24 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int): del df["timestamp_utc"] # sanity check - has_nan = df.isna().any().any() - if has_nan: - raise ValueError("Dataset contains Nan values.") + num_rows_with_nan = df.isna().any(axis=1).sum() + assert ( + num_rows_with_nan == num_unaligned_images + ), f"Found
{num_rows_with_nan} rows with NaN values but {num_unaligned_images} unaligned images." + if num_unaligned_images > max_episode * 2: + # We allow a few unaligned images, typically at the beginning and end of the episodes for instance + # but if there are too many, we raise an error to avoid large chunks of missing data + raise ValueError( + f"Found {num_unaligned_images} unaligned images out of {max_episode} episodes. " + f"Check the timestamps of the cameras." + ) + + # Drop rows with NaN values now that we double checked and convert episode_index to int + df = df.dropna() + df["episode_index"] = df["episode_index"].astype(int) # sanity check episode indices go from 0 to n-1 + assert df["episode_index"].max() == max_episode ep_ids = [ep_idx for ep_idx, _ in df.groupby("episode_index")] expected_ep_ids = list(range(df["episode_index"].max() + 1)) if ep_ids != expected_ep_ids: raise ValueError( @@ -214,8 +241,6 @@ def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=Tru if fps is None: fps = 30 - else: - raise NotImplementedError() if not video: raise NotImplementedError() diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index cb2fee95..23d828fa 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -243,10 +243,11 @@ def load_previous_and_future_frames( is_pad = min_ > tolerance_s # check violated query timestamps are all outside the episode range - assert ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all(), ( - f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tolerance_s=}) inside episode range." - "This might be due to synchronization issues with timestamps during data collection." - ) + if not ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all(): + raise ValueError( + f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tolerance_s=}) inside episode range."
+ "This might be due to synchronization issues with timestamps during data collection." + ) # get dataset indices corresponding to frames to be loaded data_ids = ep_data_ids[argmin_] diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py index 853acbc3..ca9ce113 100644 --- a/lerobot/common/logger.py +++ b/lerobot/common/logger.py @@ -189,7 +189,7 @@ class Logger: training_state["scheduler"] = scheduler.state_dict() torch.save(training_state, save_dir / self.training_state_file_name) - def save_checkpont( + def save_checkpoint( self, train_step: int, policy: Policy, diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 784e9fc6..5dbef74b 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -164,7 +164,10 @@ def rollout( # VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't # available of none of the envs finished. if "final_info" in info: - successes = [info["is_success"] if info is not None else False for info in info["final_info"]] + successes = [ + info["is_success"] if info is not None and "is_success" in info else False + for info in info["final_info"] + ] else: successes = [False] * env.num_envs diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 860412bd..83a511e5 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -345,7 +345,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No logging.info(f"Checkpoint policy after step {step}") # Note: Save with step as the identifier, and format it to have at least 6 digits but more if # needed (choose 6 as a minimum for consistency without being overkill). - logger.save_checkpont( + logger.save_checkpoint( step, policy, optimizer,