forked from tangger/lerobot
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
@@ -45,12 +45,20 @@ def get_stats_einops_patterns(dataset, num_workers=0):
|
||||
if key in dataset.meta.camera_keys:
|
||||
# sanity check that images are channel first
|
||||
_, c, h, w = batch[key].shape
|
||||
assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
|
||||
assert (
|
||||
c < h and c < w
|
||||
), f"expect channel first images, but instead {batch[key].shape}"
|
||||
|
||||
# sanity check that images are float32 in range [0,1]
|
||||
assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}"
|
||||
assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}"
|
||||
assert batch[key].min() >= 0, f"expect pixels greater than 1, but instead {batch[key].min()=}"
|
||||
assert (
|
||||
batch[key].dtype == torch.float32
|
||||
), f"expect torch.float32, but instead {batch[key].dtype=}"
|
||||
assert (
|
||||
batch[key].max() <= 1
|
||||
), f"expect pixels lower than 1, but instead {batch[key].max()=}"
|
||||
assert (
|
||||
batch[key].min() >= 0
|
||||
), f"expect pixels greater than 1, but instead {batch[key].min()=}"
|
||||
|
||||
stats_patterns[key] = "b c h w -> c 1 1"
|
||||
elif batch[key].ndim == 2:
|
||||
@@ -98,7 +106,11 @@ def compute_stats(dataset, batch_size=8, num_workers=8, max_num_samples=None):
|
||||
running_item_count = 0 # for online mean computation
|
||||
dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
|
||||
for i, batch in enumerate(
|
||||
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
|
||||
tqdm.tqdm(
|
||||
dataloader,
|
||||
total=ceil(max_num_samples / batch_size),
|
||||
desc="Compute mean, min, max",
|
||||
)
|
||||
):
|
||||
this_batch_size = len(batch["index"])
|
||||
running_item_count += this_batch_size
|
||||
@@ -113,9 +125,16 @@ def compute_stats(dataset, batch_size=8, num_workers=8, max_num_samples=None):
|
||||
# and x is the current batch mean. Some rearrangement is then required to avoid risking
|
||||
# numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
|
||||
# x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
|
||||
mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
|
||||
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
|
||||
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
|
||||
mean[key] = (
|
||||
mean[key]
|
||||
+ this_batch_size * (batch_mean - mean[key]) / running_item_count
|
||||
)
|
||||
max[key] = torch.maximum(
|
||||
max[key], einops.reduce(batch[key], pattern, "max")
|
||||
)
|
||||
min[key] = torch.minimum(
|
||||
min[key], einops.reduce(batch[key], pattern, "min")
|
||||
)
|
||||
|
||||
if i == ceil(max_num_samples / batch_size) - 1:
|
||||
break
|
||||
@@ -124,7 +143,9 @@ def compute_stats(dataset, batch_size=8, num_workers=8, max_num_samples=None):
|
||||
running_item_count = 0 # for online std computation
|
||||
dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
|
||||
for i, batch in enumerate(
|
||||
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
|
||||
tqdm.tqdm(
|
||||
dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std"
|
||||
)
|
||||
):
|
||||
this_batch_size = len(batch["index"])
|
||||
running_item_count += this_batch_size
|
||||
@@ -138,7 +159,9 @@ def compute_stats(dataset, batch_size=8, num_workers=8, max_num_samples=None):
|
||||
# Numerically stable update step for mean computation (where the mean is over squared
|
||||
# residuals).See notes in the mean computation loop above.
|
||||
batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
|
||||
std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
|
||||
std[key] = (
|
||||
std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
|
||||
)
|
||||
|
||||
if i == ceil(max_num_samples / batch_size) - 1:
|
||||
break
|
||||
@@ -177,13 +200,19 @@ def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]:
|
||||
# compute `max(dataset_0["max"], dataset_1["max"], ...)`
|
||||
stats[data_key][stat_key] = einops.reduce(
|
||||
torch.stack(
|
||||
[ds.meta.stats[data_key][stat_key] for ds in ls_datasets if data_key in ds.meta.stats],
|
||||
[
|
||||
ds.meta.stats[data_key][stat_key]
|
||||
for ds in ls_datasets
|
||||
if data_key in ds.meta.stats
|
||||
],
|
||||
dim=0,
|
||||
),
|
||||
"n ... -> ...",
|
||||
stat_key,
|
||||
)
|
||||
total_samples = sum(d.num_frames for d in ls_datasets if data_key in d.meta.stats)
|
||||
total_samples = sum(
|
||||
d.num_frames for d in ls_datasets if data_key in d.meta.stats
|
||||
)
|
||||
# Compute the "sum" statistic by multiplying each mean by the number of samples in the respective
|
||||
# dataset, then divide by total_samples to get the overall "mean".
|
||||
# NOTE: the brackets around (d.num_frames / total_samples) are needed tor minimize the risk of
|
||||
|
||||
Reference in New Issue
Block a user