forked from tangger/lerobot
refactor(env): introduce explicit gym ID handling in EnvConfig/factory (#2234)
* refactor(env): introduce explicit gym ID handling in EnvConfig/factory
This commit introduces properties for the gym package/ID associated
with and environment config. They default to the current defaults
(`gym_{package_name}/{task_id}`) to avoid breaking changes, but allow
for easier use of external gym environments.
Subclasses of `EnvConfig` can override the default properties to allow
the factory to import (i.e. register) the gym env from a specific module,
and also instantiate the env from any ID string.
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* more changes
* quality
* fix test
---------
Co-authored-by: Ben Sprenger <ben.sprenger@rogers.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
@@ -37,6 +37,16 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
@property
|
||||
def package_name(self) -> str:
|
||||
"""Package name to import if environment not found in gym registry"""
|
||||
return f"gym_{self.type}"
|
||||
|
||||
@property
|
||||
def gym_id(self) -> str:
|
||||
"""ID string used in gym.make() to instantiate the environment"""
|
||||
return f"{self.package_name}/{self.task}"
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def gym_kwargs(self) -> dict:
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
import importlib
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.envs.registration import registry as gym_registry
|
||||
|
||||
from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv
|
||||
|
||||
@@ -84,17 +85,24 @@ def make_env(
|
||||
gym_kwargs=cfg.gym_kwargs,
|
||||
env_cls=env_cls,
|
||||
)
|
||||
package_name = f"gym_{cfg.type}"
|
||||
try:
|
||||
importlib.import_module(package_name)
|
||||
except ModuleNotFoundError as e:
|
||||
print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`")
|
||||
raise e
|
||||
|
||||
gym_handle = f"{package_name}/{cfg.task}"
|
||||
if cfg.gym_id not in gym_registry:
|
||||
print(f"gym id '{cfg.gym_id}' not found, attempting to import '{cfg.package_name}'...")
|
||||
try:
|
||||
importlib.import_module(cfg.package_name)
|
||||
except ModuleNotFoundError as e:
|
||||
raise ModuleNotFoundError(
|
||||
f"Package '{cfg.package_name}' required for env '{cfg.type}' not found. "
|
||||
f"Please install it or check PYTHONPATH."
|
||||
) from e
|
||||
|
||||
if cfg.gym_id not in gym_registry:
|
||||
raise gym.error.NameNotFound(
|
||||
f"Environment '{cfg.gym_id}' not registered even after importing '{cfg.package_name}'."
|
||||
)
|
||||
|
||||
def _make_one():
|
||||
return gym.make(gym_handle, disable_env_checker=cfg.disable_env_checker, **(cfg.gym_kwargs or {}))
|
||||
return gym.make(cfg.gym_id, disable_env_checker=cfg.disable_env_checker, **(cfg.gym_kwargs or {}))
|
||||
|
||||
vec = env_cls([_make_one for _ in range(n_envs)], autoreset_mode=gym.vector.AutoresetMode.SAME_STEP)
|
||||
|
||||
|
||||
@@ -14,13 +14,17 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import gymnasium as gym
|
||||
import pytest
|
||||
import torch
|
||||
from gymnasium.envs.registration import register, registry as gym_registry
|
||||
from gymnasium.utils.env_checker import check_env
|
||||
|
||||
import lerobot
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.envs.configs import EnvConfig
|
||||
from lerobot.envs.factory import make_env, make_env_config
|
||||
from lerobot.envs.utils import preprocess_observation
|
||||
from tests.utils import require_env
|
||||
@@ -64,3 +68,43 @@ def test_factory(env_name):
|
||||
assert img.min() >= 0.0
|
||||
|
||||
env.close()
|
||||
|
||||
|
||||
def test_factory_custom_gym_id():
|
||||
gym_id = "dummy_gym_pkg/DummyTask-v0"
|
||||
if gym_id in gym_registry:
|
||||
pytest.skip(f"Environment ID {gym_id} is already registered")
|
||||
|
||||
@EnvConfig.register_subclass("dummy")
|
||||
@dataclass
|
||||
class DummyEnv(EnvConfig):
|
||||
task: str = "DummyTask-v0"
|
||||
fps: int = 10
|
||||
features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def package_name(self) -> str:
|
||||
return "dummy_gym_pkg"
|
||||
|
||||
@property
|
||||
def gym_id(self) -> str:
|
||||
return gym_id
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {}
|
||||
|
||||
try:
|
||||
register(id=gym_id, entry_point="gymnasium.envs.classic_control:CartPoleEnv")
|
||||
|
||||
cfg = DummyEnv()
|
||||
envs_dict = make_env(cfg, n_envs=1)
|
||||
dummy_envs = envs_dict["dummy"]
|
||||
assert len(dummy_envs) == 1
|
||||
env = next(iter(dummy_envs.values()))
|
||||
assert env is not None and isinstance(env, gym.vector.VectorEnv)
|
||||
env.close()
|
||||
|
||||
finally:
|
||||
if gym_id in gym_registry:
|
||||
del gym_registry[gym_id]
|
||||
|
||||
Reference in New Issue
Block a user