test_examples are passing

This commit is contained in:
Cadene
2024-04-10 13:45:45 +00:00
parent 6082a7bc73
commit c08003278e
4 changed files with 62 additions and 79 deletions

View File

@@ -40,7 +40,8 @@ def make_dataset(
if normalize:
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max,
# min_max_from_spec
# stats = dataset.compute_or_load_stats() if stats_path is None else torch.load(stats_path)
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
stats = {}
@@ -51,7 +52,7 @@ def make_dataset(
stats["action"] = {}
stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
else:
elif stats_path is None:
# instantiate a one frame dataset with light transform
stats_dataset = clsfunc(
dataset_id=cfg.dataset_id,
@@ -59,9 +60,8 @@ def make_dataset(
transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
)
stats = compute_or_load_stats(stats_dataset)
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
else:
stats = torch.load(stats_path)
transforms = v2.Compose(
[