refactor(env): introduce explicit gym ID handling in EnvConfig/factory (#2234)
* refactor(env): introduce explicit gym ID handling in EnvConfig/factory
This commit introduces properties for the gym package/ID associated
with and environment config. They default to the current defaults
(`gym_{package_name}/{task_id}`) to avoid breaking changes, but allow
for easier use of external gym environments.
Subclasses of `EnvConfig` can override the default properties to allow
the factory to import (i.e. register) the gym env from a specific module,
and also instantiate the env from any ID string.
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* more changes
* quality
* fix test
---------
Co-authored-by: Ben Sprenger <ben.sprenger@rogers.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
@@ -14,13 +14,17 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import gymnasium as gym
|
||||
import pytest
|
||||
import torch
|
||||
from gymnasium.envs.registration import register, registry as gym_registry
|
||||
from gymnasium.utils.env_checker import check_env
|
||||
|
||||
import lerobot
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.envs.configs import EnvConfig
|
||||
from lerobot.envs.factory import make_env, make_env_config
|
||||
from lerobot.envs.utils import preprocess_observation
|
||||
from tests.utils import require_env
|
||||
@@ -64,3 +68,43 @@ def test_factory(env_name):
|
||||
assert img.min() >= 0.0
|
||||
|
||||
env.close()
|
||||
|
||||
|
||||
def test_factory_custom_gym_id():
|
||||
gym_id = "dummy_gym_pkg/DummyTask-v0"
|
||||
if gym_id in gym_registry:
|
||||
pytest.skip(f"Environment ID {gym_id} is already registered")
|
||||
|
||||
@EnvConfig.register_subclass("dummy")
|
||||
@dataclass
|
||||
class DummyEnv(EnvConfig):
|
||||
task: str = "DummyTask-v0"
|
||||
fps: int = 10
|
||||
features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def package_name(self) -> str:
|
||||
return "dummy_gym_pkg"
|
||||
|
||||
@property
|
||||
def gym_id(self) -> str:
|
||||
return gym_id
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {}
|
||||
|
||||
try:
|
||||
register(id=gym_id, entry_point="gymnasium.envs.classic_control:CartPoleEnv")
|
||||
|
||||
cfg = DummyEnv()
|
||||
envs_dict = make_env(cfg, n_envs=1)
|
||||
dummy_envs = envs_dict["dummy"]
|
||||
assert len(dummy_envs) == 1
|
||||
env = next(iter(dummy_envs.values()))
|
||||
assert env is not None and isinstance(env, gym.vector.VectorEnv)
|
||||
env.close()
|
||||
|
||||
finally:
|
||||
if gym_id in gym_registry:
|
||||
del gym_registry[gym_id]
|
||||
|
||||
Reference in New Issue
Block a user