Dataset v2.0 (#461)

Co-authored-by: Remi <remi.cadene@huggingface.co>
This commit is contained in:
Simon Alibert
2024-11-29 19:04:00 +01:00
committed by GitHub
parent 96c7052777
commit 32eb0cec8f
71 changed files with 6115 additions and 2235 deletions

View File

@@ -19,9 +19,6 @@ from math import ceil
import einops
import torch
import tqdm
from datasets import Image
from lerobot.common.datasets.video_utils import VideoFrame
def get_stats_einops_patterns(dataset, num_workers=0):
@@ -39,15 +36,13 @@ def get_stats_einops_patterns(dataset, num_workers=0):
batch = next(iter(dataloader))
stats_patterns = {}
for key, feats_type in dataset.features.items():
# NOTE: skip language_instruction embedding in stats computation
if key == "language_instruction":
continue
for key in dataset.features:
# sanity check that tensors are not float64
assert batch[key].dtype != torch.float64
if isinstance(feats_type, (VideoFrame, Image)):
# if isinstance(feats_type, (VideoFrame, Image)):
if key in dataset.meta.camera_keys:
# sanity check that images are channel first
_, c, h, w = batch[key].shape
assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
@@ -63,7 +58,7 @@ def get_stats_einops_patterns(dataset, num_workers=0):
elif batch[key].ndim == 1:
stats_patterns[key] = "b -> 1"
else:
raise ValueError(f"{key}, {feats_type}, {batch[key].shape}")
raise ValueError(f"{key}, {batch[key].shape}")
return stats_patterns
@@ -175,39 +170,45 @@ def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]:
"""
data_keys = set()
for dataset in ls_datasets:
data_keys.update(dataset.stats.keys())
data_keys.update(dataset.meta.stats.keys())
stats = {k: {} for k in data_keys}
for data_key in data_keys:
for stat_key in ["min", "max"]:
# compute `max(dataset_0["max"], dataset_1["max"], ...)`
stats[data_key][stat_key] = einops.reduce(
torch.stack([d.stats[data_key][stat_key] for d in ls_datasets if data_key in d.stats], dim=0),
torch.stack(
[ds.meta.stats[data_key][stat_key] for ds in ls_datasets if data_key in ds.meta.stats],
dim=0,
),
"n ... -> ...",
stat_key,
)
total_samples = sum(d.num_samples for d in ls_datasets if data_key in d.stats)
total_samples = sum(d.num_frames for d in ls_datasets if data_key in d.meta.stats)
# Compute the "sum" statistic by multiplying each mean by the number of samples in the respective
# dataset, then divide by total_samples to get the overall "mean".
# NOTE: the brackets around (d.num_samples / total_samples) are needed tor minimize the risk of
# NOTE: the brackets around (d.num_frames / total_samples) are needed tor minimize the risk of
# numerical overflow!
stats[data_key]["mean"] = sum(
d.stats[data_key]["mean"] * (d.num_samples / total_samples)
d.meta.stats[data_key]["mean"] * (d.num_frames / total_samples)
for d in ls_datasets
if data_key in d.stats
if data_key in d.meta.stats
)
# The derivation for standard deviation is a little more involved but is much in the same spirit as
# the computation of the mean.
# Given two sets of data where the statistics are known:
# σ_combined = sqrt[ (n1 * (σ1^2 + d1^2) + n2 * (σ2^2 + d2^2)) / (n1 + n2) ]
# where d1 = μ1 - μ_combined, d2 = μ2 - μ_combined
# NOTE: the brackets around (d.num_samples / total_samples) are needed tor minimize the risk of
# NOTE: the brackets around (d.num_frames / total_samples) are needed tor minimize the risk of
# numerical overflow!
stats[data_key]["std"] = torch.sqrt(
sum(
(d.stats[data_key]["std"] ** 2 + (d.stats[data_key]["mean"] - stats[data_key]["mean"]) ** 2)
* (d.num_samples / total_samples)
(
d.meta.stats[data_key]["std"] ** 2
+ (d.meta.stats[data_key]["mean"] - stats[data_key]["mean"]) ** 2
)
* (d.num_frames / total_samples)
for d in ls_datasets
if data_key in d.stats
if data_key in d.meta.stats
)
)
return stats