PushtEnv inheriates AbstractEnv, Improve factory Normalization

This commit is contained in:
Cadene
2024-03-11 14:05:23 +00:00
parent ebd5c786f1
commit bdd2c801bc
3 changed files with 19 additions and 29 deletions

View File

@@ -106,7 +106,9 @@ def make_offline_buffer(
stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
transform = NormalizeTransform(stats, in_keys, mode="min_max")
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
transform = NormalizeTransform(stats, in_keys, mode=normalization_mode)
offline_buffer.set_transform(transform)
if not overwrite_sampler: