WIP stats (TODO: run tests on stats + cmpute them)

This commit is contained in:
Cadene
2024-04-04 16:36:03 +00:00
parent 1cdfbc8b52
commit c93ce35d8c
5 changed files with 157 additions and 286 deletions

View File

@@ -4,7 +4,8 @@ from pathlib import Path
import torch
from torchvision.transforms import v2
from lerobot.common.transforms import Prod
from lerobot.common.datasets.utils import compute_or_load_stats
from lerobot.common.transforms import NormalizeTransform, Prod
# DATA_DIR specifies to location where datasets are loaded. By default, DATA_DIR is None and
# we load from `$HOME/.cache/huggingface/hub/datasets`. For our unit tests, we set `DATA_DIR=tests/data`
@@ -41,9 +42,8 @@ def make_dataset(
# min_max_from_spec
# stats = dataset.compute_or_load_stats() if stats_path is None else torch.load(stats_path)
stats = {}
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
stats = {}
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
stats["observation.state"] = {}
stats["observation.state"]["min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
@@ -51,22 +51,30 @@ def make_dataset(
stats["action"] = {}
stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
else:
# instantiate a one frame dataset with light transform
stats_dataset = clsfunc(
dataset_id=cfg.dataset_id,
root=DATA_DIR,
transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
)
stats = compute_or_load_stats(stats_dataset)
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
# normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
transforms = v2.Compose(
[
# TODO(rcadene): we need to do something about image_keys
Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
# NormalizeTransform(
# stats,
# in_keys=[
# "observation.state",
# "action",
# ],
# mode=normalization_mode,
# ),
NormalizeTransform(
stats,
in_keys=[
"observation.state",
"action",
],
mode=normalization_mode,
),
]
)