Add Prod transform, Add test_factory

This commit is contained in:
Cadene
2024-02-20 14:22:16 +00:00
parent 3da6ffb2cb
commit 3dc14b5576
5 changed files with 56 additions and 12 deletions

View File

@@ -2,6 +2,7 @@ from torchrl.envs.transforms import StepCounter, TransformedEnv
from lerobot.common.envs.pusht import PushtEnv
from lerobot.common.envs.simxarm import SimxarmEnv
from lerobot.common.envs.transforms import Prod
def make_env(cfg):
@@ -25,6 +26,10 @@ def make_env(cfg):
# limit rollout to max_steps
env = TransformedEnv(env, StepCounter(max_steps=cfg.episode_length))
if cfg.env == "pusht":
# to ensure pusht is in [0,255] like simxarm
env.append_transform(Prod(in_keys=[("observation", "image")], prod=255.0))
return env