enable test_compute_stats
enable test_compute_stats
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
import io
|
||||
import logging
|
||||
import zipfile
|
||||
from copy import deepcopy
|
||||
from math import ceil
|
||||
@@ -103,13 +102,18 @@ def load_data_with_delta_timestamps(
|
||||
return data, is_pad
|
||||
|
||||
|
||||
def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
|
||||
stats_path = dataset.data_dir / "stats.pth"
|
||||
if stats_path.exists():
|
||||
return torch.load(stats_path)
|
||||
def get_stats_einops_patterns(dataset):
|
||||
"""These einops patterns will be used to aggregate batches and compute statistics."""
|
||||
stats_patterns = {
|
||||
"action": "b c -> c",
|
||||
"observation.state": "b c -> c",
|
||||
}
|
||||
for key in dataset.image_keys:
|
||||
stats_patterns[key] = "b c h w -> c 1 1"
|
||||
return stats_patterns
|
||||
|
||||
logging.info(f"compute_stats and save to {stats_path}")
|
||||
|
||||
def compute_stats(dataset, batch_size=32, max_num_samples=None):
|
||||
if max_num_samples is None:
|
||||
max_num_samples = len(dataset)
|
||||
else:
|
||||
@@ -124,13 +128,8 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
|
||||
drop_last=False,
|
||||
)
|
||||
|
||||
# these einops patterns will be used to aggregate batches and compute statistics
|
||||
stats_patterns = {
|
||||
"action": "b c -> c",
|
||||
"observation.state": "b c -> c",
|
||||
}
|
||||
for key in dataset.image_keys:
|
||||
stats_patterns[key] = "b c h w -> c 1 1"
|
||||
# get einops patterns to aggregate batches and compute statistics
|
||||
stats_patterns = get_stats_einops_patterns(dataset)
|
||||
|
||||
# mean and std will be computed incrementally while max and min will track the running value.
|
||||
mean, std, max, min = {}, {}, {}, {}
|
||||
@@ -201,7 +200,6 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
|
||||
"min": min[key],
|
||||
}
|
||||
|
||||
torch.save(stats, stats_path)
|
||||
return stats
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user