diff --git a/src/lerobot/envs/configs.py b/src/lerobot/envs/configs.py index 3aa155093..dc526114d 100644 --- a/src/lerobot/envs/configs.py +++ b/src/lerobot/envs/configs.py @@ -37,6 +37,16 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC): def type(self) -> str: return self.get_choice_name(self.__class__) + @property + def package_name(self) -> str: + """Package name to import if environment not found in gym registry""" + return f"gym_{self.type}" + + @property + def gym_id(self) -> str: + """ID string used in gym.make() to instantiate the environment""" + return f"{self.package_name}/{self.task}" + @property @abc.abstractmethod def gym_kwargs(self) -> dict: diff --git a/src/lerobot/envs/factory.py b/src/lerobot/envs/factory.py index 059e0e11a..52c7cbb96 100644 --- a/src/lerobot/envs/factory.py +++ b/src/lerobot/envs/factory.py @@ -16,6 +16,7 @@ import importlib import gymnasium as gym +from gymnasium.envs.registration import registry as gym_registry from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv @@ -84,17 +85,24 @@ def make_env( gym_kwargs=cfg.gym_kwargs, env_cls=env_cls, ) - package_name = f"gym_{cfg.type}" - try: - importlib.import_module(package_name) - except ModuleNotFoundError as e: - print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`") - raise e - gym_handle = f"{package_name}/{cfg.task}" + if cfg.gym_id not in gym_registry: + print(f"gym id '{cfg.gym_id}' not found, attempting to import '{cfg.package_name}'...") + try: + importlib.import_module(cfg.package_name) + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + f"Package '{cfg.package_name}' required for env '{cfg.type}' not found. " + f"Please install it or check PYTHONPATH." + ) from e + + if cfg.gym_id not in gym_registry: + raise gym.error.NameNotFound( + f"Environment '{cfg.gym_id}' not registered even after importing '{cfg.package_name}'." + ) def _make_one(): - return gym.make(gym_handle, disable_env_checker=cfg.disable_env_checker, **(cfg.gym_kwargs or {})) + return gym.make(cfg.gym_id, disable_env_checker=cfg.disable_env_checker, **(cfg.gym_kwargs or {})) vec = env_cls([_make_one for _ in range(n_envs)], autoreset_mode=gym.vector.AutoresetMode.SAME_STEP) diff --git a/tests/envs/test_envs.py b/tests/envs/test_envs.py index 51ea564e5..4c129dbbf 100644 --- a/tests/envs/test_envs.py +++ b/tests/envs/test_envs.py @@ -14,13 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. import importlib +from dataclasses import dataclass, field import gymnasium as gym import pytest import torch +from gymnasium.envs.registration import register, registry as gym_registry from gymnasium.utils.env_checker import check_env import lerobot +from lerobot.configs.types import PolicyFeature +from lerobot.envs.configs import EnvConfig from lerobot.envs.factory import make_env, make_env_config from lerobot.envs.utils import preprocess_observation from tests.utils import require_env @@ -64,3 +68,43 @@ def test_factory(env_name): assert img.min() >= 0.0 env.close() + + +def test_factory_custom_gym_id(): + gym_id = "dummy_gym_pkg/DummyTask-v0" + if gym_id in gym_registry: + pytest.skip(f"Environment ID {gym_id} is already registered") + + @EnvConfig.register_subclass("dummy") + @dataclass + class DummyEnv(EnvConfig): + task: str = "DummyTask-v0" + fps: int = 10 + features: dict[str, PolicyFeature] = field(default_factory=dict) + + @property + def package_name(self) -> str: + return "dummy_gym_pkg" + + @property + def gym_id(self) -> str: + return gym_id + + @property + def gym_kwargs(self) -> dict: + return {} + + try: + register(id=gym_id, entry_point="gymnasium.envs.classic_control:CartPoleEnv") + + cfg = DummyEnv() + envs_dict = make_env(cfg, n_envs=1) + dummy_envs = envs_dict["dummy"] + assert len(dummy_envs) == 1 + env = next(iter(dummy_envs.values())) + assert env is not None and isinstance(env, gym.vector.VectorEnv) + env.close() + + finally: + if gym_id in gym_registry: + del gym_registry[gym_id]