Refactor make_env
This commit is contained in:
@@ -27,14 +27,6 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
|
|||||||
if n_envs is not None and n_envs < 1:
|
if n_envs is not None and n_envs < 1:
|
||||||
raise ValueError("`n_envs must be at least 1")
|
raise ValueError("`n_envs must be at least 1")
|
||||||
|
|
||||||
kwargs = {
|
|
||||||
# "obs_type": "pixels_agent_pos",
|
|
||||||
# "render_mode": "rgb_array",
|
|
||||||
"max_episode_steps": cfg.env.episode_length,
|
|
||||||
# "visualization_width": 384,
|
|
||||||
# "visualization_height": 384,
|
|
||||||
}
|
|
||||||
|
|
||||||
package_name = f"gym_{cfg.env.name}"
|
package_name = f"gym_{cfg.env.name}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -46,12 +38,16 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
|
|||||||
raise e
|
raise e
|
||||||
|
|
||||||
gym_handle = f"{package_name}/{cfg.env.task}"
|
gym_handle = f"{package_name}/{cfg.env.task}"
|
||||||
|
gym_kwgs = cfg.env.get("gym", {})
|
||||||
|
|
||||||
|
if cfg.env.get("episode_length"):
|
||||||
|
gym_kwgs["max_episode_steps"] = cfg.env.episode_length
|
||||||
|
|
||||||
# batched version of the env that returns an observation of shape (b, c)
|
# batched version of the env that returns an observation of shape (b, c)
|
||||||
env_cls = gym.vector.AsyncVectorEnv if cfg.eval.use_async_envs else gym.vector.SyncVectorEnv
|
env_cls = gym.vector.AsyncVectorEnv if cfg.eval.use_async_envs else gym.vector.SyncVectorEnv
|
||||||
env = env_cls(
|
env = env_cls(
|
||||||
[
|
[
|
||||||
lambda: gym.make(gym_handle, disable_env_checker=True, **kwargs)
|
lambda: gym.make(gym_handle, disable_env_checker=True, **gym_kwgs)
|
||||||
for _ in range(n_envs if n_envs is not None else cfg.eval.batch_size)
|
for _ in range(n_envs if n_envs is not None else cfg.eval.batch_size)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|||||||
11
lerobot/configs/env/aloha.yaml
vendored
11
lerobot/configs/env/aloha.yaml
vendored
@@ -5,10 +5,13 @@ fps: 50
|
|||||||
env:
|
env:
|
||||||
name: aloha
|
name: aloha
|
||||||
task: AlohaInsertion-v0
|
task: AlohaInsertion-v0
|
||||||
from_pixels: True
|
|
||||||
pixels_only: False
|
|
||||||
image_size: [3, 480, 640]
|
image_size: [3, 480, 640]
|
||||||
episode_length: 400
|
|
||||||
fps: ${fps}
|
|
||||||
state_dim: 14
|
state_dim: 14
|
||||||
action_dim: 14
|
action_dim: 14
|
||||||
|
fps: ${fps}
|
||||||
|
episode_length: 400
|
||||||
|
gym:
|
||||||
|
obs_type: pixels_agent_pos
|
||||||
|
render_mode: rgb_array
|
||||||
|
visualization_width: 384
|
||||||
|
visualization_height: 384
|
||||||
|
|||||||
@@ -4,11 +4,10 @@ fps: 30
|
|||||||
|
|
||||||
env:
|
env:
|
||||||
name: dora
|
name: dora
|
||||||
task: DoraAloha-v0
|
task: DoraAloha2-v0
|
||||||
# from_pixels: True
|
|
||||||
# pixels_only: False
|
|
||||||
# image_size: [3, 480, 640]
|
|
||||||
episode_length: 400
|
|
||||||
fps: ${fps}
|
|
||||||
state_dim: 14
|
state_dim: 14
|
||||||
action_dim: 14
|
action_dim: 14
|
||||||
|
fps: ${fps}
|
||||||
|
episode_length: 400
|
||||||
|
gym:
|
||||||
|
fps: ${fps}
|
||||||
11
lerobot/configs/env/pusht.yaml
vendored
11
lerobot/configs/env/pusht.yaml
vendored
@@ -5,10 +5,13 @@ fps: 10
|
|||||||
env:
|
env:
|
||||||
name: pusht
|
name: pusht
|
||||||
task: PushT-v0
|
task: PushT-v0
|
||||||
from_pixels: True
|
|
||||||
pixels_only: False
|
|
||||||
image_size: 96
|
image_size: 96
|
||||||
episode_length: 300
|
|
||||||
fps: ${fps}
|
|
||||||
state_dim: 2
|
state_dim: 2
|
||||||
action_dim: 2
|
action_dim: 2
|
||||||
|
fps: ${fps}
|
||||||
|
episode_length: 300
|
||||||
|
gym:
|
||||||
|
obs_type: pixels_agent_pos
|
||||||
|
render_mode: rgb_array
|
||||||
|
visualization_width: 384
|
||||||
|
visualization_height: 384
|
||||||
|
|||||||
11
lerobot/configs/env/xarm.yaml
vendored
11
lerobot/configs/env/xarm.yaml
vendored
@@ -5,10 +5,13 @@ fps: 15
|
|||||||
env:
|
env:
|
||||||
name: xarm
|
name: xarm
|
||||||
task: XarmLift-v0
|
task: XarmLift-v0
|
||||||
from_pixels: True
|
|
||||||
pixels_only: False
|
|
||||||
image_size: 84
|
image_size: 84
|
||||||
episode_length: 25
|
|
||||||
fps: ${fps}
|
|
||||||
state_dim: 4
|
state_dim: 4
|
||||||
action_dim: 4
|
action_dim: 4
|
||||||
|
fps: ${fps}
|
||||||
|
episode_length: 25
|
||||||
|
gym:
|
||||||
|
obs_type: pixels_agent_pos
|
||||||
|
render_mode: rgb_array
|
||||||
|
visualization_width: 384
|
||||||
|
visualization_height: 384
|
||||||
|
|||||||
Reference in New Issue
Block a user