refactor(env): introduce explicit gym ID handling in EnvConfig/factory (#2234)

* refactor(env): introduce explicit gym ID handling in EnvConfig/factory

This commit introduces properties for the gym package/ID associated
with and environment config. They default to the current defaults
(`gym_{package_name}/{task_id}`) to avoid breaking changes, but allow
for easier use of external gym environments.

Subclasses of `EnvConfig` can override the default properties to allow
the factory to import (i.e. register) the gym env from a specific module,
and also instantiate the env from any ID string.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* more changes

* quality

* fix test

---------

Co-authored-by: Ben Sprenger <ben.sprenger@rogers.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
Jade Choghari
2025-10-19 20:50:00 +02:00
committed by GitHub
parent a97d078d95
commit a95b15ccc0
3 changed files with 70 additions and 8 deletions

View File

@@ -37,6 +37,16 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
def type(self) -> str: def type(self) -> str:
return self.get_choice_name(self.__class__) return self.get_choice_name(self.__class__)
@property
def package_name(self) -> str:
"""Package name to import if environment not found in gym registry"""
return f"gym_{self.type}"
@property
def gym_id(self) -> str:
"""ID string used in gym.make() to instantiate the environment"""
return f"{self.package_name}/{self.task}"
@property @property
@abc.abstractmethod @abc.abstractmethod
def gym_kwargs(self) -> dict: def gym_kwargs(self) -> dict:

View File

@@ -16,6 +16,7 @@
import importlib import importlib
import gymnasium as gym import gymnasium as gym
from gymnasium.envs.registration import registry as gym_registry
from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv
@@ -84,17 +85,24 @@ def make_env(
gym_kwargs=cfg.gym_kwargs, gym_kwargs=cfg.gym_kwargs,
env_cls=env_cls, env_cls=env_cls,
) )
package_name = f"gym_{cfg.type}"
try:
importlib.import_module(package_name)
except ModuleNotFoundError as e:
print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`")
raise e
gym_handle = f"{package_name}/{cfg.task}" if cfg.gym_id not in gym_registry:
print(f"gym id '{cfg.gym_id}' not found, attempting to import '{cfg.package_name}'...")
try:
importlib.import_module(cfg.package_name)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
f"Package '{cfg.package_name}' required for env '{cfg.type}' not found. "
f"Please install it or check PYTHONPATH."
) from e
if cfg.gym_id not in gym_registry:
raise gym.error.NameNotFound(
f"Environment '{cfg.gym_id}' not registered even after importing '{cfg.package_name}'."
)
def _make_one(): def _make_one():
return gym.make(gym_handle, disable_env_checker=cfg.disable_env_checker, **(cfg.gym_kwargs or {})) return gym.make(cfg.gym_id, disable_env_checker=cfg.disable_env_checker, **(cfg.gym_kwargs or {}))
vec = env_cls([_make_one for _ in range(n_envs)], autoreset_mode=gym.vector.AutoresetMode.SAME_STEP) vec = env_cls([_make_one for _ in range(n_envs)], autoreset_mode=gym.vector.AutoresetMode.SAME_STEP)

View File

@@ -14,13 +14,17 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import importlib import importlib
from dataclasses import dataclass, field
import gymnasium as gym import gymnasium as gym
import pytest import pytest
import torch import torch
from gymnasium.envs.registration import register, registry as gym_registry
from gymnasium.utils.env_checker import check_env from gymnasium.utils.env_checker import check_env
import lerobot import lerobot
from lerobot.configs.types import PolicyFeature
from lerobot.envs.configs import EnvConfig
from lerobot.envs.factory import make_env, make_env_config from lerobot.envs.factory import make_env, make_env_config
from lerobot.envs.utils import preprocess_observation from lerobot.envs.utils import preprocess_observation
from tests.utils import require_env from tests.utils import require_env
@@ -64,3 +68,43 @@ def test_factory(env_name):
assert img.min() >= 0.0 assert img.min() >= 0.0
env.close() env.close()
def test_factory_custom_gym_id():
gym_id = "dummy_gym_pkg/DummyTask-v0"
if gym_id in gym_registry:
pytest.skip(f"Environment ID {gym_id} is already registered")
@EnvConfig.register_subclass("dummy")
@dataclass
class DummyEnv(EnvConfig):
task: str = "DummyTask-v0"
fps: int = 10
features: dict[str, PolicyFeature] = field(default_factory=dict)
@property
def package_name(self) -> str:
return "dummy_gym_pkg"
@property
def gym_id(self) -> str:
return gym_id
@property
def gym_kwargs(self) -> dict:
return {}
try:
register(id=gym_id, entry_point="gymnasium.envs.classic_control:CartPoleEnv")
cfg = DummyEnv()
envs_dict = make_env(cfg, n_envs=1)
dummy_envs = envs_dict["dummy"]
assert len(dummy_envs) == 1
env = next(iter(dummy_envs.values()))
assert env is not None and isinstance(env, gym.vector.VectorEnv)
env.close()
finally:
if gym_id in gym_registry:
del gym_registry[gym_id]