PushtEnv inheriates AbstractEnv, Improve factory Normalization
This commit is contained in:
@@ -106,7 +106,9 @@ def make_offline_buffer(
|
||||
stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
||||
stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
||||
|
||||
transform = NormalizeTransform(stats, in_keys, mode="min_max")
|
||||
# TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
|
||||
normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
|
||||
transform = NormalizeTransform(stats, in_keys, mode=normalization_mode)
|
||||
offline_buffer.set_transform(transform)
|
||||
|
||||
if not overwrite_sampler:
|
||||
|
||||
Reference in New Issue
Block a user