enable test_compute_stats

enable test_compute_stats
This commit is contained in:
Cadene
2024-04-10 17:10:46 +00:00
parent 4c3d8b061e
commit 9874652c2f
3 changed files with 83 additions and 40 deletions

View File

@@ -1,10 +1,11 @@
import logging
import os
from pathlib import Path
import torch
from torchvision.transforms import v2
from lerobot.common.datasets.utils import compute_or_load_stats
from lerobot.common.datasets.utils import compute_stats
from lerobot.common.transforms import NormalizeTransform, Prod
# DATA_DIR specifies to location where datasets are loaded. By default, DATA_DIR is None and
@@ -59,7 +60,15 @@ def make_dataset(
root=DATA_DIR,
transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
)
stats = compute_or_load_stats(stats_dataset)
# load stats if the file exists already or compute stats and save it
precomputed_stats_path = stats_dataset.data_dir / "stats.pth"
if precomputed_stats_path.exists():
stats = torch.load(precomputed_stats_path)
else:
logging.info(f"compute_stats and save to {precomputed_stats_path}")
stats = compute_stats(stats_dataset)
torch.save(stats, stats_path)
else:
stats = torch.load(stats_path)