test_examples are passing
This commit is contained in:
@@ -40,7 +40,8 @@ def make_dataset(
|
||||
if normalize:
|
||||
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max,
|
||||
# min_max_from_spec
|
||||
# stats = dataset.compute_or_load_stats() if stats_path is None else torch.load(stats_path)
|
||||
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
|
||||
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
|
||||
|
||||
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
|
||||
stats = {}
|
||||
@@ -51,7 +52,7 @@ def make_dataset(
|
||||
stats["action"] = {}
|
||||
stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
||||
stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
||||
else:
|
||||
elif stats_path is None:
|
||||
# instantiate a one frame dataset with light transform
|
||||
stats_dataset = clsfunc(
|
||||
dataset_id=cfg.dataset_id,
|
||||
@@ -59,9 +60,8 @@ def make_dataset(
|
||||
transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
|
||||
)
|
||||
stats = compute_or_load_stats(stats_dataset)
|
||||
|
||||
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
|
||||
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
|
||||
else:
|
||||
stats = torch.load(stats_path)
|
||||
|
||||
transforms = v2.Compose(
|
||||
[
|
||||
|
||||
Reference in New Issue
Block a user