fix bug in augment_dataset_quantile_stats.py that was not detecting… (#2106)
* fix bug in `augment_dataset_quantile_stats.py` that was not detecting the image features because we were looping over `hf_dataset`. Now we loop over the dataset itself.

* Update src/lerobot/datasets/v30/augment_dataset_quantile_stats.py

Signed-off-by: Michel Aractingi <michel.aractingi@huggingface.co>

---------

Signed-off-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
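The root cause, as a minimal sketch (feature names are illustrative; the behavior is as described in this commit): for video-backed datasets the frames live in encoded video files rather than in the Arrow table, so slicing `dataset.hf_dataset` never surfaces the image keys, while indexing the dataset itself decodes frames through `__getitem__`.

# Hypothetical keys, for illustration only; the contrast is what matters.
columns = dataset.hf_dataset[0:10]           # raw table columns: "action", "observation.state", ...
print("observation.images.cam" in columns)   # False -> no stats were computed for images

item = dataset[0]                            # decodes video frames and returns all declared features
print("observation.images.cam" in item)      # True -> image features are detected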
@@ -40,10 +40,12 @@ from pathlib import Path

 import numpy as np
 import torch
+from huggingface_hub import HfApi
+from requests import HTTPError
 from tqdm import tqdm

 from lerobot.datasets.compute_stats import DEFAULT_QUANTILES, aggregate_stats, get_feature_stats
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
 from lerobot.datasets.utils import write_stats
 from lerobot.utils.utils import init_logging

@@ -85,13 +87,27 @@ def process_single_episode(dataset: LeRobotDataset, episode_idx: int) -> dict:
     start_idx = dataset.meta.episodes[episode_idx]["dataset_from_index"]
     end_idx = dataset.meta.episodes[episode_idx]["dataset_to_index"]

+    collected_data: dict[str, list] = {}
+    for idx in range(start_idx, end_idx):
+        item = dataset[idx]
+        for key, value in item.items():
+            if key not in dataset.features:
+                continue
+
+            if key not in collected_data:
+                collected_data[key] = []
+            collected_data[key].append(value)
+
     ep_stats = {}
-    for key, data in dataset.hf_dataset[start_idx:end_idx].items():
+    for key, data_list in collected_data.items():
         if dataset.features[key]["dtype"] == "string":
             continue

-        data = torch.stack(data).cpu().numpy()
+        data = torch.stack(data_list).cpu().numpy()
         if dataset.features[key]["dtype"] in ["image", "video"]:
+            if data.dtype == np.uint8:
+                data = data.astype(np.float32) / 255.0
+
             axes_to_reduce = (0, 2, 3)
             keepdims = True
         else:
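The new `np.uint8` branch matters because decoded frames may arrive as raw 0-255 pixel values, while the rest of the stats pipeline works on the [0, 1] scale (the old code's post-hoc `/ 255.0` correction, removed in the next hunk, implies as much). A toy sketch of the normalization and the reduction shape (toy array; mirrors `axes_to_reduce=(0, 2, 3)` and `keepdims=True` above; the exact quantile values are an assumption, analogous to DEFAULT_QUANTILES):

import numpy as np

frames = np.random.randint(0, 256, size=(50, 3, 96, 96), dtype=np.uint8)  # one episode of images
normalized = frames.astype(np.float32) / 255.0

# Reduce over batch/height/width, keeping a per-channel axis.
mean = normalized.mean(axis=(0, 2, 3), keepdims=True)   # shape (1, 3, 1, 1), values in [0, 1]
q01, q99 = np.quantile(normalized, [0.01, 0.99])
assert 0.0 <= mean.min() and mean.max() <= 1.0          # stats now live on the expected [0, 1] scale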
@@ -103,12 +119,9 @@ def process_single_episode(dataset: LeRobotDataset, episode_idx: int) -> dict:
         )

         if dataset.features[key]["dtype"] in ["image", "video"]:
-            for k, v in ep_stats[key].items():
-                if dataset.features[key]["dtype"] == "video":
-                    v = v / 255.0
-                if k != "count":
-                    v = np.squeeze(v, axis=0)
-                ep_stats[key][k] = v
+            ep_stats[key] = {
+                k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
+            }

     return ep_stats

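Two things happen in this hunk: the post-hoc `v = v / 255.0` for video stats is dropped, since raw uint8 frames are now rescaled before the stats are computed (see the previous hunk), and the squeeze loop collapses into a comprehension. A toy sketch of the remaining shape bookkeeping (stat shapes assume the image reduction above):

import numpy as np

ep_stats_key = {
    "mean": np.zeros((1, 3, 1, 1)),  # from reducing axes (0, 2, 3) with keepdims=True
    "std": np.ones((1, 3, 1, 1)),
    "count": np.array([50]),         # frame count; has no leading batch axis to drop
}
ep_stats_key = {k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats_key.items()}
assert ep_stats_key["mean"].shape == (3, 1, 1)
assert ep_stats_key["count"].shape == (1,)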
@@ -121,25 +134,39 @@ def compute_quantile_stats_for_dataset(dataset: LeRobotDataset) -> dict[str, dic

     Returns:
         Dictionary containing aggregated statistics with quantiles

+    Note:
+        Video decoding operations are not thread-safe, so we process episodes sequentially
+        when video keys are present. For datasets without videos, we use parallel processing
+        with ThreadPoolExecutor for better performance.
     """
     logging.info(f"Computing quantile statistics for dataset with {dataset.num_episodes} episodes")

     episode_stats_list = []
-    max_workers = min(dataset.num_episodes, 16)
+    has_videos = len(dataset.meta.video_keys) > 0

-    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
-        future_to_episode = {
-            executor.submit(process_single_episode, dataset, episode_idx): episode_idx
-            for episode_idx in range(dataset.num_episodes)
-        }
-        episode_results = {}
-        with tqdm(total=dataset.num_episodes, desc="Processing episodes") as pbar:
-            for future in concurrent.futures.as_completed(future_to_episode):
-                episode_idx = future_to_episode[future]
-                ep_stats = future.result()
-                episode_results[episode_idx] = ep_stats
-                pbar.update(1)
+    if has_videos:
+        logging.info("Dataset contains video keys - using sequential processing for thread safety")
+        for episode_idx in tqdm(range(dataset.num_episodes), desc="Processing episodes"):
+            ep_stats = process_single_episode(dataset, episode_idx)
+            episode_stats_list.append(ep_stats)
+    else:
+        logging.info("Dataset has no video keys - using parallel processing for better performance")
+        max_workers = min(dataset.num_episodes, 16)
+
+        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+            future_to_episode = {
+                executor.submit(process_single_episode, dataset, episode_idx): episode_idx
+                for episode_idx in range(dataset.num_episodes)
+            }
+
+            episode_results = {}
+            with tqdm(total=dataset.num_episodes, desc="Processing episodes") as pbar:
+                for future in concurrent.futures.as_completed(future_to_episode):
+                    episode_idx = future_to_episode[future]
+                    ep_stats = future.result()
+                    episode_results[episode_idx] = ep_stats
+                    pbar.update(1)

     for episode_idx in range(dataset.num_episodes):
         if episode_idx in episode_results:
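Condensed, the new dispatch is: stay sequential when any video key exists (decoders aren't thread-safe, per the new docstring), otherwise fan out and restore episode order from the future map. A sketch under those assumptions (`process_single_episode` as defined above; the `results` list is my simplification of the diff's `episode_results` dict):

import concurrent.futures

from tqdm import tqdm

def run_all_episodes(dataset):
    results = [None] * dataset.num_episodes
    if len(dataset.meta.video_keys) > 0:
        # Video decoding is not thread-safe: process episodes one at a time.
        for ep in tqdm(range(dataset.num_episodes), desc="Processing episodes"):
            results[ep] = process_single_episode(dataset, ep)
    else:
        # No videos: fan out, then slot results back by episode index so the
        # aggregation order is deterministic regardless of completion order.
        workers = min(dataset.num_episodes, 16)
        with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as pool:
            futures = {
                pool.submit(process_single_episode, dataset, ep): ep
                for ep in range(dataset.num_episodes)
            }
            for fut in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
                results[futures[fut]] = fut.result()
    return results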
@@ -186,6 +213,14 @@ def augment_dataset_with_quantile_stats(
     logging.info("Successfully updated dataset with quantile statistics")
     dataset.push_to_hub()

+    hub_api = HfApi()
+    try:
+        hub_api.delete_tag(repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
+    except HTTPError as e:
+        logging.info(f"tag={CODEBASE_VERSION} probably doesn't exist. Skipping exception ({e})")
+        pass
+    hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=None, repo_type="dataset")
+

 def main():
     """Main function to run the augmentation script."""
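Hub tags can't be moved in place, hence the delete-then-create dance around `CODEBASE_VERSION` above (with `revision=None`, `HfApi.create_tag` tags the head of the default branch, which now contains the re-uploaded stats). Why the re-tag matters to consumers, as a sketch (assuming they pin the codebase tag when resolving the dataset; `repo_id` as in the surrounding function):

from huggingface_hub import snapshot_download

# Resolving by tag fetches whatever commit the tag points at. Without the
# re-tag above, this would still fetch the revision from before the stats fix.
local_dir = snapshot_download(repo_id, revision=CODEBASE_VERSION, repo_type="dataset")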