forked from tangger/lerobot
feat(sim): EnvHub - allow loading envs from the hub (#2121)
* add env from the hub support * add safe loading * changes * add tests, docs * more * style/cleaning * order --------- Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
This commit is contained in:
@@ -17,6 +17,7 @@ import importlib
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from gymnasium.envs.registration import register, registry as gym_registry
|
||||
@@ -26,7 +27,11 @@ import lerobot
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.envs.configs import EnvConfig
|
||||
from lerobot.envs.factory import make_env, make_env_config
|
||||
from lerobot.envs.utils import preprocess_observation
|
||||
from lerobot.envs.utils import (
|
||||
_normalize_hub_result,
|
||||
_parse_hub_url,
|
||||
preprocess_observation,
|
||||
)
|
||||
from tests.utils import require_env
|
||||
|
||||
OBS_TYPES = ["state", "pixels", "pixels_agent_pos"]
|
||||
@@ -108,3 +113,156 @@ def test_factory_custom_gym_id():
|
||||
finally:
|
||||
if gym_id in gym_registry:
|
||||
del gym_registry[gym_id]
|
||||
|
||||
|
||||
# Hub environment loading tests
|
||||
|
||||
|
||||
def test_make_env_hub_url_parsing():
|
||||
"""Test URL parsing for hub environment references."""
|
||||
# simple repo_id
|
||||
repo_id, revision, file_path = _parse_hub_url("user/repo")
|
||||
assert repo_id == "user/repo"
|
||||
assert revision is None
|
||||
assert file_path == "env.py"
|
||||
|
||||
# repo with revision
|
||||
repo_id, revision, file_path = _parse_hub_url("user/repo@main")
|
||||
assert repo_id == "user/repo"
|
||||
assert revision == "main"
|
||||
assert file_path == "env.py"
|
||||
|
||||
# repo with custom file path
|
||||
repo_id, revision, file_path = _parse_hub_url("user/repo:custom_env.py")
|
||||
assert repo_id == "user/repo"
|
||||
assert revision is None
|
||||
assert file_path == "custom_env.py"
|
||||
|
||||
# repo with revision and custom file path
|
||||
repo_id, revision, file_path = _parse_hub_url("user/repo@v1.0:envs/my_env.py")
|
||||
assert repo_id == "user/repo"
|
||||
assert revision == "v1.0"
|
||||
assert file_path == "envs/my_env.py"
|
||||
|
||||
# repo with commit hash
|
||||
repo_id, revision, file_path = _parse_hub_url("org/repo@abc123def456")
|
||||
assert repo_id == "org/repo"
|
||||
assert revision == "abc123def456"
|
||||
assert file_path == "env.py"
|
||||
|
||||
|
||||
def test_normalize_hub_result():
|
||||
"""Test normalization of different return types from hub make_env."""
|
||||
# test with VectorEnv (most common case)
|
||||
mock_vec_env = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1")])
|
||||
result = _normalize_hub_result(mock_vec_env)
|
||||
assert isinstance(result, dict)
|
||||
assert len(result) == 1
|
||||
suite_name = next(iter(result))
|
||||
assert 0 in result[suite_name]
|
||||
assert isinstance(result[suite_name][0], gym.vector.VectorEnv)
|
||||
mock_vec_env.close()
|
||||
|
||||
# test with single Env
|
||||
mock_env = gym.make("CartPole-v1")
|
||||
result = _normalize_hub_result(mock_env)
|
||||
assert isinstance(result, dict)
|
||||
suite_name = next(iter(result))
|
||||
assert 0 in result[suite_name]
|
||||
assert isinstance(result[suite_name][0], gym.vector.VectorEnv)
|
||||
result[suite_name][0].close()
|
||||
|
||||
# test with dict (already normalized)
|
||||
mock_vec_env = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1")])
|
||||
input_dict = {"my_suite": {0: mock_vec_env}}
|
||||
result = _normalize_hub_result(input_dict)
|
||||
assert result == input_dict
|
||||
assert "my_suite" in result
|
||||
assert 0 in result["my_suite"]
|
||||
mock_vec_env.close()
|
||||
|
||||
# test with invalid type
|
||||
with pytest.raises(ValueError, match="Hub `make_env` must return"):
|
||||
_normalize_hub_result("invalid_type")
|
||||
|
||||
|
||||
def test_make_env_from_hub_requires_trust_remote_code():
|
||||
"""Test that loading from hub requires explicit trust_remote_code=True."""
|
||||
hub_id = "lerobot/cartpole-env"
|
||||
|
||||
# Should raise RuntimeError when trust_remote_code=False (default)
|
||||
with pytest.raises(RuntimeError, match="Refusing to execute remote code"):
|
||||
make_env(hub_id, trust_remote_code=False)
|
||||
|
||||
# Should also raise when not specified (defaults to False)
|
||||
with pytest.raises(RuntimeError, match="Refusing to execute remote code"):
|
||||
make_env(hub_id)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"hub_id",
|
||||
[
|
||||
"lerobot/cartpole-env",
|
||||
"lerobot/cartpole-env@main",
|
||||
"lerobot/cartpole-env:env.py",
|
||||
],
|
||||
)
|
||||
def test_make_env_from_hub_with_trust(hub_id):
|
||||
"""Test loading environment from Hugging Face Hub with trust_remote_code=True."""
|
||||
# load environment from hub
|
||||
envs_dict = make_env(hub_id, n_envs=2, trust_remote_code=True)
|
||||
|
||||
# verify structure
|
||||
assert isinstance(envs_dict, dict)
|
||||
assert len(envs_dict) >= 1
|
||||
|
||||
# get the first suite and task
|
||||
suite_name = next(iter(envs_dict))
|
||||
task_id = next(iter(envs_dict[suite_name]))
|
||||
env = envs_dict[suite_name][task_id]
|
||||
|
||||
# verify it's a vector environment
|
||||
assert isinstance(env, gym.vector.VectorEnv)
|
||||
assert env.num_envs == 2
|
||||
|
||||
# test basic environment interaction
|
||||
obs, info = env.reset()
|
||||
assert obs is not None
|
||||
assert isinstance(obs, (dict, np.ndarray))
|
||||
|
||||
# take a random action
|
||||
action = env.action_space.sample()
|
||||
obs, reward, terminated, truncated, info = env.step(action)
|
||||
assert obs is not None
|
||||
assert isinstance(reward, np.ndarray)
|
||||
assert len(reward) == 2
|
||||
|
||||
# clean up
|
||||
env.close()
|
||||
|
||||
|
||||
def test_make_env_from_hub_async():
|
||||
"""Test loading hub environment with async vector environments."""
|
||||
hub_id = "lerobot/cartpole-env"
|
||||
|
||||
# load with async envs
|
||||
envs_dict = make_env(hub_id, n_envs=2, use_async_envs=True, trust_remote_code=True)
|
||||
|
||||
suite_name = next(iter(envs_dict))
|
||||
task_id = next(iter(envs_dict[suite_name]))
|
||||
env = envs_dict[suite_name][task_id]
|
||||
|
||||
# verify it's an async vector environment
|
||||
assert isinstance(env, gym.vector.AsyncVectorEnv)
|
||||
assert env.num_envs == 2
|
||||
|
||||
# test basic interaction
|
||||
obs, info = env.reset()
|
||||
assert obs is not None
|
||||
|
||||
action = env.action_space.sample()
|
||||
obs, reward, terminated, truncated, info = env.step(action)
|
||||
assert len(reward) == 2
|
||||
|
||||
# clean up
|
||||
env.close()
|
||||
|
||||
Reference in New Issue
Block a user