enable test_compute_stats

enable test_compute_stats
This commit is contained in:
Cadene
2024-04-10 17:10:46 +00:00
parent 4c3d8b061e
commit 9874652c2f
3 changed files with 83 additions and 40 deletions

View File

@@ -1,10 +1,11 @@
import logging
import os
from pathlib import Path
import torch
from torchvision.transforms import v2
from lerobot.common.datasets.utils import compute_or_load_stats
from lerobot.common.datasets.utils import compute_stats
from lerobot.common.transforms import NormalizeTransform, Prod
# DATA_DIR specifies to location where datasets are loaded. By default, DATA_DIR is None and
@@ -59,7 +60,15 @@ def make_dataset(
root=DATA_DIR,
transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
)
stats = compute_or_load_stats(stats_dataset)
# load stats if the file exists already or compute stats and save it
precomputed_stats_path = stats_dataset.data_dir / "stats.pth"
if precomputed_stats_path.exists():
stats = torch.load(precomputed_stats_path)
else:
logging.info(f"compute_stats and save to {precomputed_stats_path}")
stats = compute_stats(stats_dataset)
torch.save(stats, stats_path)
else:
stats = torch.load(stats_path)

View File

@@ -1,5 +1,4 @@
import io
import logging
import zipfile
from copy import deepcopy
from math import ceil
@@ -103,13 +102,18 @@ def load_data_with_delta_timestamps(
return data, is_pad
def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
stats_path = dataset.data_dir / "stats.pth"
if stats_path.exists():
return torch.load(stats_path)
def get_stats_einops_patterns(dataset):
"""These einops patterns will be used to aggregate batches and compute statistics."""
stats_patterns = {
"action": "b c -> c",
"observation.state": "b c -> c",
}
for key in dataset.image_keys:
stats_patterns[key] = "b c h w -> c 1 1"
return stats_patterns
logging.info(f"compute_stats and save to {stats_path}")
def compute_stats(dataset, batch_size=32, max_num_samples=None):
if max_num_samples is None:
max_num_samples = len(dataset)
else:
@@ -124,13 +128,8 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
drop_last=False,
)
# these einops patterns will be used to aggregate batches and compute statistics
stats_patterns = {
"action": "b c -> c",
"observation.state": "b c -> c",
}
for key in dataset.image_keys:
stats_patterns[key] = "b c h w -> c 1 1"
# get einops patterns to aggregate batches and compute statistics
stats_patterns = get_stats_einops_patterns(dataset)
# mean and std will be computed incrementally while max and min will track the running value.
mean, std, max, min = {}, {}, {}, {}
@@ -201,7 +200,6 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
"min": min[key],
}
torch.save(stats, stats_path)
return stats