refactor(env): introduce explicit gym ID handling in EnvConfig/factory (#2234)

* refactor(env): introduce explicit gym ID handling in EnvConfig/factory

This commit introduces properties for the gym package/ID associated
with and environment config. They default to the current defaults
(`gym_{package_name}/{task_id}`) to avoid breaking changes, but allow
for easier use of external gym environments.

Subclasses of `EnvConfig` can override the default properties to allow
the factory to import (i.e. register) the gym env from a specific module,
and also instantiate the env from any ID string.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* more changes

* quality

* fix test

---------

Co-authored-by: Ben Sprenger <ben.sprenger@rogers.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
Jade Choghari
2025-10-19 20:50:00 +02:00
committed by GitHub
parent a97d078d95
commit a95b15ccc0
3 changed files with 70 additions and 8 deletions

View File

@@ -14,13 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from dataclasses import dataclass, field
import gymnasium as gym
import pytest
import torch
from gymnasium.envs.registration import register, registry as gym_registry
from gymnasium.utils.env_checker import check_env
import lerobot
from lerobot.configs.types import PolicyFeature
from lerobot.envs.configs import EnvConfig
from lerobot.envs.factory import make_env, make_env_config
from lerobot.envs.utils import preprocess_observation
from tests.utils import require_env
@@ -64,3 +68,43 @@ def test_factory(env_name):
assert img.min() >= 0.0
env.close()
def test_factory_custom_gym_id():
gym_id = "dummy_gym_pkg/DummyTask-v0"
if gym_id in gym_registry:
pytest.skip(f"Environment ID {gym_id} is already registered")
@EnvConfig.register_subclass("dummy")
@dataclass
class DummyEnv(EnvConfig):
task: str = "DummyTask-v0"
fps: int = 10
features: dict[str, PolicyFeature] = field(default_factory=dict)
@property
def package_name(self) -> str:
return "dummy_gym_pkg"
@property
def gym_id(self) -> str:
return gym_id
@property
def gym_kwargs(self) -> dict:
return {}
try:
register(id=gym_id, entry_point="gymnasium.envs.classic_control:CartPoleEnv")
cfg = DummyEnv()
envs_dict = make_env(cfg, n_envs=1)
dummy_envs = envs_dict["dummy"]
assert len(dummy_envs) == 1
env = next(iter(dummy_envs.values()))
assert env is not None and isinstance(env, gym.vector.VectorEnv)
env.close()
finally:
if gym_id in gym_registry:
del gym_registry[gym_id]