forked from tangger/lerobot
Move normalization to policy for act and diffusion (#90)
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
@@ -2,18 +2,13 @@ import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torchvision.transforms import v2
|
||||
|
||||
from lerobot.common.transforms import NormalizeTransform
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
||||
|
||||
|
||||
def make_dataset(
|
||||
cfg,
|
||||
# set normalize=False to remove all transformations and keep images unnormalized in [0,255]
|
||||
normalize=True,
|
||||
stats_path=None,
|
||||
split="train",
|
||||
):
|
||||
if cfg.env.name == "xarm":
|
||||
@@ -33,58 +28,26 @@ def make_dataset(
|
||||
else:
|
||||
raise ValueError(cfg.env.name)
|
||||
|
||||
transforms = None
|
||||
if normalize:
|
||||
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max,
|
||||
# min_max_from_spec
|
||||
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
|
||||
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
|
||||
|
||||
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
|
||||
stats = {}
|
||||
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
|
||||
stats["observation.state"] = {}
|
||||
stats["observation.state"]["min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
|
||||
stats["observation.state"]["max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
|
||||
stats["action"] = {}
|
||||
stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
||||
stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
||||
elif stats_path is None:
|
||||
# load a first dataset to access precomputed stats
|
||||
stats_dataset = clsfunc(
|
||||
dataset_id=cfg.dataset_id,
|
||||
split="train",
|
||||
root=DATA_DIR,
|
||||
)
|
||||
stats = stats_dataset.stats
|
||||
else:
|
||||
stats = torch.load(stats_path)
|
||||
|
||||
transforms = v2.Compose(
|
||||
[
|
||||
NormalizeTransform(
|
||||
stats,
|
||||
in_keys=[
|
||||
"observation.state",
|
||||
"action",
|
||||
],
|
||||
mode=normalization_mode,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
delta_timestamps = cfg.policy.get("delta_timestamps")
|
||||
if delta_timestamps is not None:
|
||||
for key in delta_timestamps:
|
||||
if isinstance(delta_timestamps[key], str):
|
||||
delta_timestamps[key] = eval(delta_timestamps[key])
|
||||
|
||||
# TODO(rcadene): add data augmentations
|
||||
|
||||
dataset = clsfunc(
|
||||
dataset_id=cfg.dataset_id,
|
||||
split=split,
|
||||
root=DATA_DIR,
|
||||
delta_timestamps=delta_timestamps,
|
||||
transform=transforms,
|
||||
)
|
||||
|
||||
if cfg.get("override_dataset_stats"):
|
||||
for key, stats_dict in cfg.override_dataset_stats.items():
|
||||
for stats_type, listconfig in stats_dict.items():
|
||||
# example of stats_type: min, max, mean, std
|
||||
stats = OmegaConf.to_container(listconfig, resolve=True)
|
||||
dataset.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
|
||||
|
||||
return dataset
|
||||
|
||||
Reference in New Issue
Block a user