wrap dm_control aloha into gymnasium (TODO: properly seeding the env)
This commit is contained in:
@@ -41,25 +41,21 @@ from .utils import DEVICE, DEFAULT_CONFIG_PATH
|
||||
# print("data from rollout:", simple_rollout(100))
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO")
|
||||
@pytest.mark.parametrize(
|
||||
"task,from_pixels,pixels_only",
|
||||
"env_task, obs_type",
|
||||
[
|
||||
("sim_insertion", True, False),
|
||||
("sim_insertion", True, True),
|
||||
("sim_transfer_cube", True, False),
|
||||
("sim_transfer_cube", True, True),
|
||||
# ("AlohaInsertion-v0", "state"),
|
||||
("AlohaInsertion-v0", "pixels"),
|
||||
("AlohaInsertion-v0", "pixels_agent_pos"),
|
||||
("AlohaTransferCube-v0", "pixels"),
|
||||
("AlohaTransferCube-v0", "pixels_agent_pos"),
|
||||
],
|
||||
)
|
||||
def test_aloha(task, from_pixels, pixels_only):
|
||||
env = AlohaEnv(
|
||||
task,
|
||||
from_pixels=from_pixels,
|
||||
pixels_only=pixels_only,
|
||||
image_size=[3, 480, 640] if from_pixels else None,
|
||||
)
|
||||
# print_spec_rollout(env)
|
||||
check_env_specs(env)
|
||||
def test_aloha(env_task, obs_type):
|
||||
from lerobot.common.envs import aloha as gym_aloha # noqa: F401
|
||||
env = gym.make(f"gym_aloha/{env_task}", obs_type=obs_type)
|
||||
check_env(env)
|
||||
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
Reference in New Issue
Block a user