Add mode to NormalizeTransform with mean_std or min_max (Not fully tested)
This commit is contained in:
@@ -1,7 +1,5 @@
|
||||
from torchrl.envs.transforms import StepCounter, TransformedEnv
|
||||
|
||||
from lerobot.common.envs.transforms import Prod
|
||||
|
||||
|
||||
def make_env(cfg, transform=None):
|
||||
kwargs = {
|
||||
@@ -28,12 +26,8 @@ def make_env(cfg, transform=None):
|
||||
# limit rollout to max_steps
|
||||
env = TransformedEnv(env, StepCounter(max_steps=cfg.env.episode_length))
|
||||
|
||||
if cfg.env.name == "pusht":
|
||||
# to ensure pusht is in [0,255] like simxarm
|
||||
env.append_transform(Prod(in_keys=[("observation", "image")], prod=255.0))
|
||||
|
||||
if transform is not None:
|
||||
# useful to add mean and std normalization
|
||||
# useful to add normalization
|
||||
env.append_transform(transform)
|
||||
|
||||
return env
|
||||
|
||||
Reference in New Issue
Block a user