enable test_compute_stats

enable test_compute_stats
This commit is contained in:
Cadene
2024-04-10 17:10:46 +00:00
parent 4c3d8b061e
commit 9874652c2f
3 changed files with 83 additions and 40 deletions

View File

@@ -1,5 +1,4 @@
import io
import logging
import zipfile
from copy import deepcopy
from math import ceil
@@ -103,13 +102,18 @@ def load_data_with_delta_timestamps(
return data, is_pad
def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
stats_path = dataset.data_dir / "stats.pth"
if stats_path.exists():
return torch.load(stats_path)
def get_stats_einops_patterns(dataset):
    """Return the einops patterns used to aggregate batches and compute statistics.

    The result maps each data key to an einops pattern string:

    - ``"action"`` and ``"observation.state"`` use ``"b c -> c"``.
    - every key listed in ``dataset.image_keys`` uses ``"b c h w -> c 1 1"``.

    Args:
        dataset: Any object exposing an iterable ``image_keys`` attribute
            containing the keys of its image entries.

    Returns:
        A new dict of ``{key: pattern}``. Image keys are inserted after the
        two fixed entries, so an image key with the same name as a fixed
        entry would overwrite it.
    """
    stats_patterns = {
        "action": "b c -> c",
        "observation.state": "b c -> c",
    }
    # Image entries keep two trailing singleton axes in the pattern output
    # (presumably so the result broadcasts against image tensors; confirm
    # against the caller that consumes these patterns).
    for key in dataset.image_keys:
        stats_patterns[key] = "b c h w -> c 1 1"
    return stats_patterns
logging.info(f"compute_stats and save to {stats_path}")
def compute_stats(dataset, batch_size=32, max_num_samples=None):
if max_num_samples is None:
max_num_samples = len(dataset)
else:
@@ -124,13 +128,8 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
drop_last=False,
)
# these einops patterns will be used to aggregate batches and compute statistics
stats_patterns = {
"action": "b c -> c",
"observation.state": "b c -> c",
}
for key in dataset.image_keys:
stats_patterns[key] = "b c h w -> c 1 1"
# get einops patterns to aggregate batches and compute statistics
stats_patterns = get_stats_einops_patterns(dataset)
# mean and std will be computed incrementally while max and min will track the running value.
mean, std, max, min = {}, {}, {}, {}
@@ -201,7 +200,6 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
"min": min[key],
}
torch.save(stats, stats_path)
return stats