From 3f6dfa4916dcd07bdd5c9a36facbae401624e13c Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Mon, 8 Apr 2024 16:18:53 +0200 Subject: [PATCH] Add gym-aloha, rename simxarm -> xarm, refactor --- .github/workflows/test.yml | 4 +- lerobot/__init__.py | 6 +- lerobot/common/datasets/factory.py | 6 +- .../common/datasets/{simxarm.py => xarm.py} | 2 +- lerobot/common/envs/factory.py | 39 ++++-------- lerobot/configs/env/aloha.yaml | 4 +- lerobot/configs/env/pusht.yaml | 4 +- .../configs/env/{simxarm.yaml => xarm.yaml} | 4 +- lerobot/scripts/train.py | 2 +- poetry.lock | 33 +++++++++-- pyproject.toml | 15 +++-- tests/test_available.py | 2 +- tests/test_datasets.py | 2 +- tests/test_envs.py | 59 ++++++------------- tests/test_policies.py | 6 +- 15 files changed, 91 insertions(+), 97 deletions(-) rename lerobot/common/datasets/{simxarm.py => xarm.py} (99%) rename lerobot/configs/env/{simxarm.yaml => xarm.yaml} (82%) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 478be771d..c1b14780f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -204,7 +204,7 @@ jobs: source .venv/bin/activate python lerobot/scripts/train.py \ policy=tdmpc \ - env=simxarm \ + env=xarm \ wandb.enable=False \ offline_steps=1 \ online_steps=1 \ @@ -229,6 +229,6 @@ jobs: python lerobot/scripts/eval.py \ --config lerobot/configs/default.yaml \ policy=tdmpc \ - env=simxarm \ + env=xarm \ eval_episodes=1 \ device=cpu diff --git a/lerobot/__init__.py b/lerobot/__init__.py index 5cf8bdb8b..4673aab07 100644 --- a/lerobot/__init__.py +++ b/lerobot/__init__.py @@ -27,7 +27,7 @@ from lerobot.__version__ import __version__ # noqa: F401 available_envs = [ "aloha", "pusht", - "simxarm", + "xarm", ] available_tasks_per_env = { @@ -36,7 +36,7 @@ available_tasks_per_env = { "sim_transfer_cube", ], "pusht": ["pusht"], - "simxarm": ["lift"], + "xarm": ["lift"], } available_datasets_per_env = { @@ -47,7 +47,7 @@ available_datasets_per_env = { 
"aloha_sim_transfer_cube_scripted", ], "pusht": ["pusht"], - "simxarm": ["xarm_lift_medium"], + "xarm": ["xarm_lift_medium"], } available_datasets = [dataset for env in available_envs for dataset in available_datasets_per_env[env]] diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index c22ae698b..0dab5d4bf 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -19,10 +19,10 @@ def make_dataset( normalize=True, stats_path=None, ): - if cfg.env.name == "simxarm": - from lerobot.common.datasets.simxarm import SimxarmDataset + if cfg.env.name == "xarm": + from lerobot.common.datasets.xarm import XarmDataset - clsfunc = SimxarmDataset + clsfunc = XarmDataset elif cfg.env.name == "pusht": from lerobot.common.datasets.pusht import PushtDataset diff --git a/lerobot/common/datasets/simxarm.py b/lerobot/common/datasets/xarm.py similarity index 99% rename from lerobot/common/datasets/simxarm.py rename to lerobot/common/datasets/xarm.py index 7bddf608f..733267ab9 100644 --- a/lerobot/common/datasets/simxarm.py +++ b/lerobot/common/datasets/xarm.py @@ -24,7 +24,7 @@ def download(raw_dir): zip_path.unlink() -class SimxarmDataset(torch.utils.data.Dataset): +class XarmDataset(torch.utils.data.Dataset): available_datasets = [ "xarm_lift_medium", ] diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index 9225cbc57..ed5cb926b 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -1,3 +1,5 @@ +import importlib + import gymnasium as gym @@ -8,43 +10,28 @@ def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv: """ kwargs = { "obs_type": "pixels_agent_pos", + "render_mode": "rgb_array", "max_episode_steps": cfg.env.episode_length, "visualization_width": 384, "visualization_height": 384, } - if cfg.env.name == "simxarm": - import gym_xarm # noqa: F401 + package_name = f"gym_{cfg.env.name}" - assert cfg.env.task == "lift" - env_fn = lambda: 
gym.make( # noqa: E731 - "gym_xarm/XarmLift-v0", - **kwargs, + try: + importlib.import_module(package_name) + except ModuleNotFoundError as e: + print( + f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.env.name}]'`" ) - elif cfg.env.name == "pusht": - import gym_pusht # noqa: F401 + raise e - # assert kwargs["seed"] > 200, "Seed 0-200 are used for the demonstration dataset, so we don't want to seed the eval env with this range." - env_fn = lambda: gym.make( # noqa: E731 - "gym_pusht/PushTPixels-v0", - **kwargs, - ) - elif cfg.env.name == "aloha": - from lerobot.common.envs import aloha as gym_aloha # noqa: F401 - - kwargs["task"] = cfg.env.task - - env_fn = lambda: gym.make( # noqa: E731 - "gym_aloha/AlohaInsertion-v0", - **kwargs, - ) - else: - raise ValueError(cfg.env.name) + handle = f"{package_name}/{cfg.env.handle}" if num_parallel_envs == 0: # non-batched version of the env that returns an observation of shape (c) - env = env_fn() + env = gym.make(handle, **kwargs) else: # batched version of the env that returns an observation of shape (b, c) - env = gym.vector.SyncVectorEnv([env_fn for _ in range(num_parallel_envs)]) + env = gym.vector.SyncVectorEnv([lambda: gym.make(handle, **kwargs) for _ in range(num_parallel_envs)]) return env diff --git a/lerobot/configs/env/aloha.yaml b/lerobot/configs/env/aloha.yaml index 2bfbbaa88..146a45989 100644 --- a/lerobot/configs/env/aloha.yaml +++ b/lerobot/configs/env/aloha.yaml @@ -4,7 +4,7 @@ eval_episodes: 50 eval_freq: 7500 save_freq: 75000 log_freq: 250 -# TODO: same as simxarm, need to adjust +# TODO: same as xarm, need to adjust offline_steps: 25000 online_steps: 25000 @@ -14,6 +14,8 @@ dataset_id: aloha_sim_insertion_human env: name: aloha + handle: AlohaInsertion-v0 + # TODO(aliberts): replace task with handle task: insertion from_pixels: True pixels_only: False diff --git a/lerobot/configs/env/pusht.yaml b/lerobot/configs/env/pusht.yaml index 0050530e1..aafd766a5 100644 --- 
a/lerobot/configs/env/pusht.yaml +++ b/lerobot/configs/env/pusht.yaml @@ -4,7 +4,7 @@ eval_episodes: 50 eval_freq: 7500 save_freq: 75000 log_freq: 250 -# TODO: same as simxarm, need to adjust +# TODO: same as xarm, need to adjust offline_steps: 25000 online_steps: 25000 @@ -14,6 +14,8 @@ dataset_id: pusht env: name: pusht + handle: PushT-v0 + # TODO(aliberts): replace task with handle task: pusht from_pixels: True pixels_only: False diff --git a/lerobot/configs/env/simxarm.yaml b/lerobot/configs/env/xarm.yaml similarity index 82% rename from lerobot/configs/env/simxarm.yaml rename to lerobot/configs/env/xarm.yaml index 843f80c67..5eb1700e7 100644 --- a/lerobot/configs/env/simxarm.yaml +++ b/lerobot/configs/env/xarm.yaml @@ -12,7 +12,9 @@ fps: 15 dataset_id: xarm_lift_medium env: - name: simxarm + name: xarm + handle: XarmLift-v0 + # TODO(aliberts): replace task with handle task: lift from_pixels: True pixels_only: False diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index cca26902e..3bf09b5fd 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -162,7 +162,7 @@ def train(cfg: dict, out_dir=None, job_name=None): logger = Logger(out_dir, job_name, cfg) log_output_dir(out_dir) - logging.info(f"{cfg.env.task=}") + logging.info(f"{cfg.env.handle=}") logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})") logging.info(f"{cfg.online_steps=}") logging.info(f"{cfg.env.action_repeat=}") diff --git a/poetry.lock b/poetry.lock index b9e31930e..98449df4b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -879,10 +879,30 @@ files = [ [package.extras] protobuf = ["grpcio-tools (>=1.62.1)"] +[[package]] +name = "gym-aloha" +version = "0.1.0" +description = "A gym environment for ALOHA" +optional = true +python-versions = "^3.10" +files = [] +develop = false + +[package.dependencies] +dm-control = "1.0.14" +gymnasium = "^0.29.1" +mujoco = "^2.3.7" + +[package.source] +type = "git" +url = 
"git@github.com:huggingface/gym-aloha.git" +reference = "HEAD" +resolved_reference = "ec7200831e36c14e343cf7d275c6b047f2fe9d11" + [[package]] name = "gym-pusht" version = "0.1.0" -description = "PushT environment for LeRobot" +description = "A gymnasium environment for PushT." optional = true python-versions = "^3.10" files = [] @@ -900,7 +920,7 @@ shapely = "^2.0.3" type = "git" url = "git@github.com:huggingface/gym-pusht.git" reference = "HEAD" -resolved_reference = "0fe4449cca5a2b08f529f7a07fbf5b9df24962ec" +resolved_reference = "6c9893504f670ff069d0f759a733e971ea1efdbf" [[package]] name = "gym-xarm" @@ -920,7 +940,7 @@ mujoco = "^2.3.7" type = "git" url = "git@github.com:huggingface/gym-xarm.git" reference = "HEAD" -resolved_reference = "2eb83fc4fc871b9d271c946d169e42f226ac3a7c" +resolved_reference = "08ddd5a9400783a6898bbf3c3014fc5da3961b9d" [[package]] name = "gymnasium" @@ -3630,10 +3650,11 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.link testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] [extras] -pusht = ["gym_pusht"] -xarm = ["gym_xarm"] +aloha = ["gym-aloha"] +pusht = ["gym-pusht"] +xarm = ["gym-xarm"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "c9524cdf000eaa755a2ab3be669118222b4f8b1c262013f103f6874cbd54eeb6" +content-hash = "cb450ac7186e004536d75409edd42cd96062f7b1fd47822a5460d12eab8762f9" diff --git a/pyproject.toml b/pyproject.toml index b7e1b9fb7..e78a502d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,14 +52,17 @@ robomimic = "0.2.0" gymnasium-robotics = "^1.2.4" gymnasium = "^0.29.1" cmake = "^3.29.0.1" -gym_pusht = { git = "git@github.com:huggingface/gym-pusht.git", optional = true} -gym_xarm = { git = "git@github.com:huggingface/gym-xarm.git", optional = true} -# gym_pusht = { path = "../gym-pusht", 
develop = true, optional = true} -# gym_xarm = { path = "../gym-xarm", develop = true, optional = true} +gym-pusht = { git = "git@github.com:huggingface/gym-pusht.git", optional = true} +gym-xarm = { git = "git@github.com:huggingface/gym-xarm.git", optional = true} +gym-aloha = { git = "git@github.com:huggingface/gym-aloha.git", optional = true} +# gym-pusht = { path = "../gym-pusht", develop = true, optional = true} +# gym-xarm = { path = "../gym-xarm", develop = true, optional = true} +# gym-aloha = { path = "../gym-aloha", develop = true, optional = true} [tool.poetry.extras] -pusht = ["gym_pusht"] -xarm = ["gym_xarm"] +pusht = ["gym-pusht"] +xarm = ["gym-xarm"] +aloha = ["gym-aloha"] [tool.poetry.group.dev.dependencies] pre-commit = "^3.6.2" diff --git a/tests/test_available.py b/tests/test_available.py index 8a2ece380..8df2c945a 100644 --- a/tests/test_available.py +++ b/tests/test_available.py @@ -19,7 +19,7 @@ import lerobot # from gym_pusht.envs import PushtEnv # from gym_xarm.envs import SimxarmEnv -# from lerobot.common.datasets.simxarm import SimxarmDataset +# from lerobot.common.datasets.xarm import XarmDataset # from lerobot.common.datasets.aloha import AlohaDataset # from lerobot.common.datasets.pusht import PushtDataset diff --git a/tests/test_datasets.py b/tests/test_datasets.py index e7777c167..e24d7b4d9 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -11,7 +11,7 @@ from .utils import DEVICE, DEFAULT_CONFIG_PATH @pytest.mark.parametrize( "env_name,dataset_id,policy_name", [ - ("simxarm", "xarm_lift_medium", "tdmpc"), + ("xarm", "xarm_lift_medium", "tdmpc"), ("pusht", "pusht", "diffusion"), ("aloha", "aloha_sim_insertion_human", "act"), ("aloha", "aloha_sim_insertion_scripted", "act"), diff --git a/tests/test_envs.py b/tests/test_envs.py index effe4032b..8fcf3a488 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -1,3 +1,4 @@ +import importlib import pytest import torch from lerobot.common.datasets.factory import
make_dataset @@ -13,49 +14,25 @@ from .utils import DEVICE, DEFAULT_CONFIG_PATH @pytest.mark.parametrize( - "env_task, obs_type", + "env_name, handle, obs_type", [ # ("AlohaInsertion-v0", "state"), - ("AlohaInsertion-v0", "pixels"), - ("AlohaInsertion-v0", "pixels_agent_pos"), - ("AlohaTransferCube-v0", "pixels"), - ("AlohaTransferCube-v0", "pixels_agent_pos"), + ("aloha", "AlohaInsertion-v0", "pixels"), + ("aloha", "AlohaInsertion-v0", "pixels_agent_pos"), + ("aloha", "AlohaTransferCube-v0", "pixels"), + ("aloha", "AlohaTransferCube-v0", "pixels_agent_pos"), + ("xarm", "XarmLift-v0", "state"), + ("xarm", "XarmLift-v0", "pixels"), + ("xarm", "XarmLift-v0", "pixels_agent_pos"), + ("pusht", "PushT-v0", "state"), + ("pusht", "PushT-v0", "pixels"), + ("pusht", "PushT-v0", "pixels_agent_pos"), ], ) -def test_aloha(env_task, obs_type): - from lerobot.common.envs import aloha as gym_aloha # noqa: F401 - env = gym.make(f"gym_aloha/{env_task}", obs_type=obs_type) - check_env(env.unwrapped) - - - -@pytest.mark.parametrize( - "env_task, obs_type", - [ - ("XarmLift-v0", "state"), - ("XarmLift-v0", "pixels"), - ("XarmLift-v0", "pixels_agent_pos"), - # TODO(aliberts): Add gym_xarm other tasks - ], -) -def test_xarm(env_task, obs_type): - import gym_xarm # noqa: F401 - env = gym.make(f"gym_xarm/{env_task}", obs_type=obs_type) - check_env(env.unwrapped) - - - -@pytest.mark.parametrize( - "env_task, obs_type", - [ - ("PushTPixels-v0", "state"), - ("PushTPixels-v0", "pixels"), - ("PushTPixels-v0", "pixels_agent_pos"), - ], -) -def test_pusht(env_task, obs_type): - import gym_pusht # noqa: F401 - env = gym.make(f"gym_pusht/{env_task}", obs_type=obs_type) +def test_env(env_name, handle, obs_type): + package_name = f"gym_{env_name}" + importlib.import_module(package_name) + env = gym.make(f"{package_name}/{handle}", obs_type=obs_type) check_env(env.unwrapped) @@ -63,7 +40,7 @@ def test_pusht(env_task, obs_type): "env_name", [ "pusht", - "simxarm", + "xarm", "aloha", ], ) @@ -76,7 +53,7 
@@ def test_factory(env_name): dataset = make_dataset(cfg) env = make_env(cfg, num_parallel_envs=1) - obs, info = env.reset() + obs, _ = env.reset() obs = preprocess_observation(obs, transform=dataset.transform) for key in dataset.image_keys: img = obs[key] diff --git a/tests/test_policies.py b/tests/test_policies.py index c79bff94f..82033b786 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -12,15 +12,15 @@ from .utils import DEVICE, DEFAULT_CONFIG_PATH @pytest.mark.parametrize( "env_name,policy_name,extra_overrides", [ - ("simxarm", "tdmpc", ["policy.mpc=true"]), + ("xarm", "tdmpc", ["policy.mpc=true"]), ("pusht", "tdmpc", ["policy.mpc=false"]), ("pusht", "diffusion", []), # ("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_human"]), #("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_scripted"]), #("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_human"]), #("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_scripted"]), - # TODO(aliberts): simxarm not working with diffusion - # ("simxarm", "diffusion", []), + # TODO(aliberts): xarm not working with diffusion + # ("xarm", "diffusion", []), ], ) def test_policy(env_name, policy_name, extra_overrides):