Sanitize cfg.env

This commit is contained in:
Cadene
2024-02-25 12:02:29 +00:00
parent 9b469c4232
commit ed80db2846
6 changed files with 46 additions and 42 deletions

View File

@@ -7,26 +7,26 @@ from lerobot.common.envs.transforms import Prod
def make_env(cfg):
kwargs = {
"frame_skip": cfg.action_repeat,
"from_pixels": cfg.from_pixels,
"pixels_only": cfg.pixels_only,
"image_size": cfg.image_size,
"frame_skip": cfg.env.action_repeat,
"from_pixels": cfg.env.from_pixels,
"pixels_only": cfg.env.pixels_only,
"image_size": cfg.env.image_size,
}
if cfg.env == "simxarm":
kwargs["task"] = cfg.task
if cfg.env.name == "simxarm":
kwargs["task"] = cfg.env.task
clsfunc = SimxarmEnv
elif cfg.env == "pusht":
elif cfg.env.name == "pusht":
clsfunc = PushtEnv
else:
raise ValueError(cfg.env)
raise ValueError(cfg.env.name)
env = clsfunc(**kwargs)
# limit rollout to max_steps
env = TransformedEnv(env, StepCounter(max_steps=cfg.episode_length))
env = TransformedEnv(env, StepCounter(max_steps=cfg.env.episode_length))
if cfg.env == "pusht":
if cfg.env.name == "pusht":
# to ensure pusht is in [0,255] like simxarm
env.append_transform(Prod(in_keys=[("observation", "image")], prod=255.0))