diff --git a/src/lerobot/datasets/v30/augment_dataset_quantile_stats.py b/src/lerobot/datasets/v30/augment_dataset_quantile_stats.py
index ff4689efa..900a43a4f 100644
--- a/src/lerobot/datasets/v30/augment_dataset_quantile_stats.py
+++ b/src/lerobot/datasets/v30/augment_dataset_quantile_stats.py
@@ -40,10 +40,12 @@
 from pathlib import Path
 
 import numpy as np
 import torch
+from huggingface_hub import HfApi
+from requests import HTTPError
 from tqdm import tqdm
 
 from lerobot.datasets.compute_stats import DEFAULT_QUANTILES, aggregate_stats, get_feature_stats
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
 from lerobot.datasets.utils import write_stats
 from lerobot.utils.utils import init_logging
@@ -85,13 +87,27 @@ def process_single_episode(dataset: LeRobotDataset, episode_idx: int) -> dict:
     start_idx = dataset.meta.episodes[episode_idx]["dataset_from_index"]
     end_idx = dataset.meta.episodes[episode_idx]["dataset_to_index"]
 
+    collected_data: dict[str, list] = {}
+    for idx in range(start_idx, end_idx):
+        item = dataset[idx]
+        for key, value in item.items():
+            if key not in dataset.features:
+                continue
+
+            if key not in collected_data:
+                collected_data[key] = []
+            collected_data[key].append(value)
+
     ep_stats = {}
-    for key, data in dataset.hf_dataset[start_idx:end_idx].items():
+    for key, data_list in collected_data.items():
         if dataset.features[key]["dtype"] == "string":
             continue
 
-        data = torch.stack(data).cpu().numpy()
+        data = torch.stack(data_list).cpu().numpy()
         if dataset.features[key]["dtype"] in ["image", "video"]:
+            if data.dtype == np.uint8:
+                data = data.astype(np.float32) / 255.0
+
             axes_to_reduce = (0, 2, 3)
             keepdims = True
         else:
@@ -103,12 +119,9 @@
         )
 
         if dataset.features[key]["dtype"] in ["image", "video"]:
-            for k, v in ep_stats[key].items():
-                if dataset.features[key]["dtype"] == "video":
-                    v = v / 255.0
-                if k != "count":
-                    v = np.squeeze(v, axis=0)
-                ep_stats[key][k] = v
+            ep_stats[key] = {
+                k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
+            }
 
     return ep_stats
 
@@ -121,25 +134,39 @@ def compute_quantile_stats_for_dataset(dataset: LeRobotDataset) -> dict[str, dic
 
     Returns:
         Dictionary containing aggregated statistics with quantiles
+
+    Note:
+        Video decoding operations are not thread-safe, so we process episodes sequentially
+        when video keys are present. For datasets without videos, we use parallel processing
+        with ThreadPoolExecutor for better performance.
     """
     logging.info(f"Computing quantile statistics for dataset with {dataset.num_episodes} episodes")
 
     episode_stats_list = []
-    max_workers = min(dataset.num_episodes, 16)
+    has_videos = len(dataset.meta.video_keys) > 0
 
-    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
-        future_to_episode = {
-            executor.submit(process_single_episode, dataset, episode_idx): episode_idx
-            for episode_idx in range(dataset.num_episodes)
-        }
+    if has_videos:
+        logging.info("Dataset contains video keys - using sequential processing for thread safety")
+        for episode_idx in tqdm(range(dataset.num_episodes), desc="Processing episodes"):
+            ep_stats = process_single_episode(dataset, episode_idx)
+            episode_stats_list.append(ep_stats)
+    else:
+        logging.info("Dataset has no video keys - using parallel processing for better performance")
+        max_workers = min(dataset.num_episodes, 16)
 
-        episode_results = {}
-        with tqdm(total=dataset.num_episodes, desc="Processing episodes") as pbar:
-            for future in concurrent.futures.as_completed(future_to_episode):
-                episode_idx = future_to_episode[future]
-                ep_stats = future.result()
-                episode_results[episode_idx] = ep_stats
-                pbar.update(1)
+        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+            future_to_episode = {
+                executor.submit(process_single_episode, dataset, episode_idx): episode_idx
+                for episode_idx in range(dataset.num_episodes)
+            }
+
+            episode_results = {}
+            with tqdm(total=dataset.num_episodes, desc="Processing episodes") as pbar:
+                for future in concurrent.futures.as_completed(future_to_episode):
+                    episode_idx = future_to_episode[future]
+                    ep_stats = future.result()
+                    episode_results[episode_idx] = ep_stats
+                    pbar.update(1)
 
         for episode_idx in range(dataset.num_episodes):
             if episode_idx in episode_results:
@@ -186,6 +213,13 @@ def augment_dataset_with_quantile_stats(
     logging.info("Successfully updated dataset with quantile statistics")
     dataset.push_to_hub()
+
+    hub_api = HfApi()
+    try:
+        hub_api.delete_tag(repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
+    except HTTPError as e:
+        logging.info(f"tag={CODEBASE_VERSION} probably doesn't exist. Skipping exception ({e})")
+    hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=None, repo_type="dataset")
 
 
 def main():
     """Main function to run the augmentation script."""