From 31b57de6339416e437dc9927832b9d30a9ad1a60 Mon Sep 17 00:00:00 2001 From: Remi Cadene Date: Wed, 29 May 2024 15:26:56 +0000 Subject: [PATCH] Refactor make_env --- lerobot/common/envs/factory.py | 14 +++++--------- lerobot/configs/env/aloha.yaml | 11 +++++++---- .../configs/env/{dora.yaml => aloha2_real.yaml} | 11 +++++------ lerobot/configs/env/pusht.yaml | 11 +++++++---- lerobot/configs/env/xarm.yaml | 11 +++++++---- 5 files changed, 31 insertions(+), 27 deletions(-) rename lerobot/configs/env/{dora.yaml => aloha2_real.yaml} (54%) diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index 7932501d1..33742e146 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -27,14 +27,6 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv if n_envs is not None and n_envs < 1: raise ValueError("`n_envs must be at least 1") - kwargs = { - # "obs_type": "pixels_agent_pos", - # "render_mode": "rgb_array", - "max_episode_steps": cfg.env.episode_length, - # "visualization_width": 384, - # "visualization_height": 384, - } - package_name = f"gym_{cfg.env.name}" try: @@ -46,12 +38,16 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv raise e gym_handle = f"{package_name}/{cfg.env.task}" + gym_kwgs = cfg.env.get("gym", {}) + + if cfg.env.get("episode_length"): + gym_kwgs["max_episode_steps"] = cfg.env.episode_length # batched version of the env that returns an observation of shape (b, c) env_cls = gym.vector.AsyncVectorEnv if cfg.eval.use_async_envs else gym.vector.SyncVectorEnv env = env_cls( [ - lambda: gym.make(gym_handle, disable_env_checker=True, **kwargs) + lambda: gym.make(gym_handle, disable_env_checker=True, **gym_kwgs) for _ in range(n_envs if n_envs is not None else cfg.eval.batch_size) ] ) diff --git a/lerobot/configs/env/aloha.yaml b/lerobot/configs/env/aloha.yaml index 95e4503d4..d93afba71 100644 --- a/lerobot/configs/env/aloha.yaml +++ b/lerobot/configs/env/aloha.yaml @@ -5,10 +5,13 @@ fps: 50 env: name: aloha task: AlohaInsertion-v0 - from_pixels: True - pixels_only: False image_size: [3, 480, 640] - episode_length: 400 - fps: ${fps} state_dim: 14 action_dim: 14 + fps: ${fps} + episode_length: 400 + gym: + obs_type: pixels_agent_pos + render_mode: rgb_array + visualization_width: 384 + visualization_height: 384 diff --git a/lerobot/configs/env/dora.yaml b/lerobot/configs/env/aloha2_real.yaml similarity index 54% rename from lerobot/configs/env/dora.yaml rename to lerobot/configs/env/aloha2_real.yaml index 3742347fb..3053fc01b 100644 --- a/lerobot/configs/env/dora.yaml +++ b/lerobot/configs/env/aloha2_real.yaml @@ -4,11 +4,10 @@ fps: 30 env: name: dora - task: DoraAloha-v0 - # from_pixels: True - # pixels_only: False - # image_size: [3, 480, 640] - episode_length: 400 - fps: ${fps} + task: DoraAloha2-v0 state_dim: 14 action_dim: 14 + fps: ${fps} + episode_length: 400 + gym: + fps: ${fps} diff --git a/lerobot/configs/env/pusht.yaml b/lerobot/configs/env/pusht.yaml index 43e9d187c..771fbbf4d 100644 --- a/lerobot/configs/env/pusht.yaml +++ b/lerobot/configs/env/pusht.yaml @@ -5,10 +5,13 @@ fps: 10 env: name: pusht task: PushT-v0 - from_pixels: True - pixels_only: False image_size: 96 - episode_length: 300 - fps: ${fps} state_dim: 2 action_dim: 2 + fps: ${fps} + episode_length: 300 + gym: + obs_type: pixels_agent_pos + render_mode: rgb_array + visualization_width: 384 + visualization_height: 384 diff --git a/lerobot/configs/env/xarm.yaml b/lerobot/configs/env/xarm.yaml index 098b03962..9dbb96f56 100644 --- a/lerobot/configs/env/xarm.yaml +++ b/lerobot/configs/env/xarm.yaml @@ -5,10 +5,13 @@ fps: 15 env: name: xarm task: XarmLift-v0 - from_pixels: True - pixels_only: False image_size: 84 - episode_length: 25 - fps: ${fps} state_dim: 4 action_dim: 4 + fps: ${fps} + episode_length: 25 + gym: + obs_type: pixels_agent_pos + render_mode: rgb_array + visualization_width: 384 + visualization_height: 384