(WIP) Add gym-xarm

This commit is contained in:
Simon Alibert
2024-04-05 15:35:20 +02:00
parent c17dffe944
commit ab3cd3a7ba
4 changed files with 54 additions and 22 deletions

View File

@@ -3,6 +3,8 @@ from tensordict import TensorDict
import torch
from torchrl.envs.utils import check_env_specs, step_mdp
from lerobot.common.datasets.factory import make_dataset
import gymnasium as gym
from gymnasium.utils.env_checker import check_env
from lerobot.common.envs.aloha.env import AlohaEnv
from lerobot.common.envs.factory import make_env
@@ -61,29 +63,26 @@ def test_aloha(task, from_pixels, pixels_only):
@pytest.mark.parametrize(
"task,from_pixels,pixels_only",
"task, obs_type",
[
("lift", False, False),
("lift", True, False),
("lift", True, True),
("XarmLift-v0", "state"),
("XarmLift-v0", "pixels"),
("XarmLift-v0", "pixels_agent_pos"),
# TODO(aliberts): Add simxarm other tasks
# ("reach", False, False),
# ("reach", True, False),
# ("push", False, False),
# ("push", True, False),
# ("peg_in_box", False, False),
# ("peg_in_box", True, False),
],
)
def test_simxarm(task, from_pixels, pixels_only):
env = SimxarmEnv(
task,
from_pixels=from_pixels,
pixels_only=pixels_only,
image_size=84 if from_pixels else None,
)
def test_xarm(env_task, obs_type):
import gym_xarm
env = gym.make(f"gym_xarm/{env_task}", obs_type=obs_type)
# env = SimxarmEnv(
# task,
# from_pixels=from_pixels,
# pixels_only=pixels_only,
# image_size=84 if from_pixels else None,
# )
# print_spec_rollout(env)
check_env_specs(env)
# check_env_specs(env)
check_env(env)
@pytest.mark.parametrize(