refactor(env): introduce explicit gym ID handling in EnvConfig/factory (#2234)

* refactor(env): introduce explicit gym ID handling in EnvConfig/factory

This commit introduces properties for the gym package/ID associated
with and environment config. They default to the current defaults
(`gym_{package_name}/{task_id}`) to avoid breaking changes, but allow
for easier use of external gym environments.

Subclasses of `EnvConfig` can override the default properties to allow
the factory to import (i.e. register) the gym env from a specific module,
and also instantiate the env from any ID string.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* more changes

* quality

* fix test

---------

Co-authored-by: Ben Sprenger <ben.sprenger@rogers.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
Jade Choghari
2025-10-19 20:50:00 +02:00
committed by GitHub
parent a97d078d95
commit a95b15ccc0
3 changed files with 70 additions and 8 deletions

View File

@@ -37,6 +37,16 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
def type(self) -> str:
return self.get_choice_name(self.__class__)
@property
def package_name(self) -> str:
"""Package name to import if environment not found in gym registry"""
return f"gym_{self.type}"
@property
def gym_id(self) -> str:
"""ID string used in gym.make() to instantiate the environment"""
return f"{self.package_name}/{self.task}"
@property
@abc.abstractmethod
def gym_kwargs(self) -> dict:

View File

@@ -16,6 +16,7 @@
import importlib
import gymnasium as gym
from gymnasium.envs.registration import registry as gym_registry
from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv
@@ -84,17 +85,24 @@ def make_env(
gym_kwargs=cfg.gym_kwargs,
env_cls=env_cls,
)
package_name = f"gym_{cfg.type}"
try:
importlib.import_module(package_name)
except ModuleNotFoundError as e:
print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`")
raise e
gym_handle = f"{package_name}/{cfg.task}"
if cfg.gym_id not in gym_registry:
print(f"gym id '{cfg.gym_id}' not found, attempting to import '{cfg.package_name}'...")
try:
importlib.import_module(cfg.package_name)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
f"Package '{cfg.package_name}' required for env '{cfg.type}' not found. "
f"Please install it or check PYTHONPATH."
) from e
if cfg.gym_id not in gym_registry:
raise gym.error.NameNotFound(
f"Environment '{cfg.gym_id}' not registered even after importing '{cfg.package_name}'."
)
def _make_one():
return gym.make(gym_handle, disable_env_checker=cfg.disable_env_checker, **(cfg.gym_kwargs or {}))
return gym.make(cfg.gym_id, disable_env_checker=cfg.disable_env_checker, **(cfg.gym_kwargs or {}))
vec = env_cls([_make_one for _ in range(n_envs)], autoreset_mode=gym.vector.AutoresetMode.SAME_STEP)

View File

@@ -14,13 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from dataclasses import dataclass, field
import gymnasium as gym
import pytest
import torch
from gymnasium.envs.registration import register, registry as gym_registry
from gymnasium.utils.env_checker import check_env
import lerobot
from lerobot.configs.types import PolicyFeature
from lerobot.envs.configs import EnvConfig
from lerobot.envs.factory import make_env, make_env_config
from lerobot.envs.utils import preprocess_observation
from tests.utils import require_env
@@ -64,3 +68,43 @@ def test_factory(env_name):
assert img.min() >= 0.0
env.close()
def test_factory_custom_gym_id():
gym_id = "dummy_gym_pkg/DummyTask-v0"
if gym_id in gym_registry:
pytest.skip(f"Environment ID {gym_id} is already registered")
@EnvConfig.register_subclass("dummy")
@dataclass
class DummyEnv(EnvConfig):
task: str = "DummyTask-v0"
fps: int = 10
features: dict[str, PolicyFeature] = field(default_factory=dict)
@property
def package_name(self) -> str:
return "dummy_gym_pkg"
@property
def gym_id(self) -> str:
return gym_id
@property
def gym_kwargs(self) -> dict:
return {}
try:
register(id=gym_id, entry_point="gymnasium.envs.classic_control:CartPoleEnv")
cfg = DummyEnv()
envs_dict = make_env(cfg, n_envs=1)
dummy_envs = envs_dict["dummy"]
assert len(dummy_envs) == 1
env = next(iter(dummy_envs.values()))
assert env is not None and isinstance(env, gym.vector.VectorEnv)
env.close()
finally:
if gym_id in gym_registry:
del gym_registry[gym_id]