Add Prod transform, Add test_factory
This commit is contained in:
@@ -2,6 +2,7 @@ from torchrl.envs.transforms import StepCounter, TransformedEnv
|
||||
|
||||
from lerobot.common.envs.pusht import PushtEnv
|
||||
from lerobot.common.envs.simxarm import SimxarmEnv
|
||||
from lerobot.common.envs.transforms import Prod
|
||||
|
||||
|
||||
def make_env(cfg):
|
||||
@@ -25,6 +26,10 @@ def make_env(cfg):
|
||||
# limit rollout to max_steps
|
||||
env = TransformedEnv(env, StepCounter(max_steps=cfg.episode_length))
|
||||
|
||||
if cfg.env == "pusht":
|
||||
# to ensure pusht is in [0,255] like simxarm
|
||||
env.append_transform(Prod(in_keys=[("observation", "image")], prod=255.0))
|
||||
|
||||
return env
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user