Add mode to NormalizeTransform with mean_std or min_max (Not fully tested)

This commit is contained in:
Remi Cadene
2024-03-03 13:19:02 +00:00
parent 48ded3dbc7
commit cbbed590a9
4 changed files with 75 additions and 33 deletions

View File

@@ -1,7 +1,5 @@
from torchrl.envs.transforms import StepCounter, TransformedEnv
from lerobot.common.envs.transforms import Prod
def make_env(cfg, transform=None):
kwargs = {
@@ -28,12 +26,8 @@ def make_env(cfg, transform=None):
# limit rollout to max_steps
env = TransformedEnv(env, StepCounter(max_steps=cfg.env.episode_length))
if cfg.env.name == "pusht":
# to ensure pusht is in [0,255] like simxarm
env.append_transform(Prod(in_keys=[("observation", "image")], prod=255.0))
if transform is not None:
# useful to add mean and std normalization
# useful to add normalization
env.append_transform(transform)
return env