From 271d92dcaae084a0ab48763e7a2efd37b9a27fe6 Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Tue, 14 Oct 2025 17:21:18 +0200 Subject: [PATCH] feat(sim): add metaworld env (#2088) * add metaworld * smol update Signed-off-by: Jade Choghari * update design * Update src/lerobot/envs/metaworld.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Jade Choghari * update * small changes * iterate on review * small fix * small fix * add docs * update doc * add better gif * smol doc fix * updage gymnasium * add note * depreciate gym-xarm * more changes * update doc * comply with mypy * more fixes * update readme * precommit * update pusht * add pusht instead * changes * style * add changes * update * revert * update v2 * chore(envs): move metaworld config to its own file + remove comments + simplify _format_raw_obs (#2200) * update final changes --------- Signed-off-by: Jade Choghari Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Steven Palma --- CONTRIBUTING.md | 1 - Makefile | 10 +- docs/source/_toctree.yml | 10 +- docs/source/installation.mdx | 2 +- docs/source/libero.mdx | 2 +- docs/source/metaworld.mdx | 80 +++++++ pyproject.toml | 11 +- src/lerobot/__init__.py | 12 - src/lerobot/envs/__init__.py | 2 +- src/lerobot/envs/configs.py | 81 ++++--- src/lerobot/envs/factory.py | 17 +- src/lerobot/envs/libero.py | 26 +- src/lerobot/envs/metaworld.py | 313 +++++++++++++++++++++++++ src/lerobot/envs/metaworld_config.json | 121 ++++++++++ src/lerobot/scripts/lerobot_eval.py | 10 +- tests/policies/test_policies.py | 3 - 16 files changed, 612 insertions(+), 89 deletions(-) create mode 100644 docs/source/metaworld.mdx create mode 100644 src/lerobot/envs/metaworld.py create mode 100644 src/lerobot/envs/metaworld_config.json diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 369af602b..a07596728 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -72,7 +72,6 @@ post it. Look at our implementations for [datasets](./src/lerobot/datasets/), [policies](./src/lerobot/policies/), environments ([aloha](https://github.com/huggingface/gym-aloha), -[xarm](https://github.com/huggingface/gym-xarm), [pusht](https://github.com/huggingface/gym-pusht)) and follow the same api design. 
diff --git a/Makefile b/Makefile index fbe8a5bae..e02f02403 100644 --- a/Makefile +++ b/Makefile @@ -119,10 +119,9 @@ test-tdmpc-ete-train: --policy.type=tdmpc \ --policy.device=$(DEVICE) \ --policy.push_to_hub=false \ - --env.type=xarm \ - --env.task=XarmLift-v0 \ + --env.type=pusht \ --env.episode_length=5 \ - --dataset.repo_id=lerobot/xarm_lift_medium \ + --dataset.repo_id=lerobot/pusht_image \ --dataset.image_transforms.enable=true \ --dataset.episodes="[0]" \ --batch_size=2 \ @@ -140,9 +139,10 @@ test-tdmpc-ete-eval: lerobot-eval \ --policy.path=tests/outputs/tdmpc/checkpoints/000002/pretrained_model \ --policy.device=$(DEVICE) \ - --env.type=xarm \ + --env.type=pusht \ --env.episode_length=5 \ - --env.task=XarmLift-v0 \ + --env.observation_height=96 \ + --env.observation_width=96 \ --eval.n_episodes=1 \ --eval.batch_size=1 diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 568bd6380..b7e71e010 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -7,8 +7,6 @@ - sections: - local: il_robots title: Imitation Learning for Robots - - local: il_sim - title: Imitation Learning in Sim - local: cameras title: Cameras - local: integrate_hardware @@ -37,9 +35,15 @@ title: π₀ (Pi0) - local: pi05 title: π₀.₅ (Pi05) + title: "Policies" +- sections: + - local: il_sim + title: Imitation Learning in Sim - local: libero title: Using Libero - title: "Policies" + - local: metaworld + title: Using MetaWorld + title: "Simulation" - sections: - local: introduction_processors title: Introduction to Robot Processors diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index 93354c2ee..f5fd09acd 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -91,7 +91,7 @@ LeRobot provides optional extras for specific functionalities. Multiple extras c ### Simulations -Install environment packages: `aloha` ([gym-aloha](https://github.com/huggingface/gym-aloha)), `xarm` ([gym-xarm](https://github.com/huggingface/gym-xarm)), or `pusht` ([gym-pusht](https://github.com/huggingface/gym-pusht)) +Install environment packages: `aloha` ([gym-aloha](https://github.com/huggingface/gym-aloha)), or `pusht` ([gym-pusht](https://github.com/huggingface/gym-pusht)) Example: ```bash diff --git a/docs/source/libero.mdx b/docs/source/libero.mdx index 3f2b92406..14f51ef3b 100644 --- a/docs/source/libero.mdx +++ b/docs/source/libero.mdx @@ -137,7 +137,7 @@ The finetuned model can be found here: We then evaluate the finetuned model using the LeRobot LIBERO implementation, by running the following command: ```bash -python src/lerobot/scripts/eval.py \ +lerobot-eval \ --output_dir=/logs/ \ --env.type=libero \ --env.task=libero_spatial,libero_object,libero_goal,libero_10 \ diff --git a/docs/source/metaworld.mdx b/docs/source/metaworld.mdx new file mode 100644 index 000000000..da90bd51d --- /dev/null +++ b/docs/source/metaworld.mdx @@ -0,0 +1,80 @@ +# Meta-World + +Meta-World is a well-designed, open-source simulation benchmark for multi-task and meta reinforcement learning in continuous-control robotic manipulation. It gives researchers a shared, realistic playground to test whether algorithms can _learn many different tasks_ and _generalize quickly to new ones_ — two central challenges for real-world robotics. 
+
+- 📄 [MetaWorld paper](https://arxiv.org/pdf/1910.10897)
+- 💻 [Original MetaWorld repo](https://github.com/Farama-Foundation/Metaworld)
+
+![MetaWorld demo](https://meta-world.github.io/figures/ml45.gif)
+
+## Why Meta-World matters
+
+- **Diverse, realistic tasks.** Meta-World bundles a large suite of simulated manipulation tasks (50 in the MT50 suite) using everyday objects and a common tabletop Sawyer arm. This diversity exposes algorithms to a wide variety of dynamics, contacts and goal specifications while keeping a consistent control and observation structure.
+- **Focus on generalization and multi-task learning.** By evaluating across task distributions that share structure but differ in goals and objects, Meta-World reveals whether an agent truly learns transferable skills rather than overfitting to a narrow task.
+- **Standardized evaluation protocol.** It provides clear evaluation modes and difficulty splits, so different methods can be compared fairly across easy, medium, hard and very-hard regimes.
+- **Empirical insight.** Past evaluations on Meta-World show impressive progress on some fronts, but also highlight that current multi-task and meta-RL methods still struggle with large, diverse task sets. That gap points to important research directions.
+
+## What it enables in LeRobot
+
+In LeRobot, you can evaluate any policy or vision-language-action (VLA) model on Meta-World tasks and get a clear success-rate measure. The integration is designed to be straightforward:
+
+- We provide a LeRobot-ready dataset for Meta-World (MT50) on the HF Hub: `https://huggingface.co/datasets/lerobot/metaworld_mt50`.
+  - This dataset is formatted for the MT50 evaluation that uses all 50 tasks (the most challenging multi-task setting).
+  - MT50 gives the policy a one-hot task vector and uses fixed object/goal positions for consistency.
+
+- Task descriptions and the exact keys required for evaluation are available in the repo/dataset — use these to ensure your policy outputs the right success signals.
+
+## Quick start: train a SmolVLA policy on Meta-World
+
+Example command to train a SmolVLA policy on a subset of tasks:
+
+```bash
+lerobot-train \
+  --policy.type=smolvla \
+  --policy.repo_id=${HF_USER}/metaworld-test \
+  --policy.load_vlm_weights=true \
+  --dataset.repo_id=lerobot/metaworld_mt50 \
+  --env.type=metaworld \
+  --env.task=assembly-v3,dial-turn-v3,handle-press-side-v3 \
+  --output_dir=./outputs/ \
+  --steps=100000 \
+  --batch_size=4 \
+  --eval.batch_size=1 \
+  --eval.n_episodes=1 \
+  --eval_freq=1000
+```
+
+Notes:
+
+- `--env.task` accepts explicit task lists (comma-separated) or difficulty groups (e.g., `env.task="hard"`).
+- Adjust `batch_size`, `steps`, and `eval_freq` to match your compute budget.
+- **Gymnasium assertion error**: If you encounter an error like
+  `AssertionError: ['human', 'rgb_array', 'depth_array']` when running MetaWorld environments, this comes from a mismatch between MetaWorld and your Gymnasium version.
+  We recommend pinning Gymnasium to ensure compatibility:
+
+```bash
+pip install "gymnasium==1.1.0"
+```
+
+## Quick start: evaluate a trained policy
+
+To evaluate a trained policy on the Meta-World medium difficulty split:
+
+```bash
+lerobot-eval \
+  --policy.path="your-policy-id" \
+  --env.type=metaworld \
+  --env.task=medium \
+  --eval.batch_size=1 \
+  --eval.n_episodes=2
+```
+
+This will run episodes and return per-task success rates using the standard Meta-World evaluation keys.
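+
+For a quick programmatic sanity check, you can also drive a single task directly through the `MetaworldEnv` wrapper from `lerobot/envs/metaworld.py`. The snippet below is a minimal sketch (random actions instead of a policy) and assumes the `metaworld` extra is installed:
+
+```python
+from lerobot.envs.metaworld import MetaworldEnv
+
+# Single-task wrapper; the "metaworld-" prefix is stripped internally.
+env = MetaworldEnv(task="metaworld-door-open-v3", obs_type="pixels_agent_pos")
+obs, info = env.reset(seed=0)
+print(obs["pixels"].shape, obs["agent_pos"].shape)  # RGB frame + 4-dim proprioceptive state
+
+for _ in range(10):
+    action = env.action_space.sample()  # replace with your policy's (4,)-shaped action
+    obs, reward, terminated, truncated, info = env.step(action)
+    if terminated:
+        print("success:", info["is_success"])
+        break
+env.close()
+```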
+ +## Practical tips + +- If you care about generalization, run on the full MT50 suite — it’s intentionally challenging and reveals strengths/weaknesses better than a few narrow tasks. +- Use the one-hot task conditioning for multi-task training (MT10 / MT50 conventions) so policies have explicit task context. +- Inspect the dataset task descriptions and the `info["is_success"]` keys when writing post-processing or logging so your success metrics line up with the benchmark. diff --git a/pyproject.toml b/pyproject.toml index 44ca596b1..6d43c33df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,7 +80,7 @@ dependencies = [ "torchvision>=0.21.0,<0.23.0", # TODO: Bumb dependency "draccus==0.10.0", # TODO: Remove == - "gymnasium>=0.29.1,<1.0.0", # TODO: Bumb dependency + "gymnasium>=1.0.0", "rerun-sdk>=0.21.0,<0.23.0", # TODO: Bumb dependency # Support dependencies @@ -133,11 +133,10 @@ test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0 video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"] # Simulation -aloha = ["gym-aloha>=0.1.1,<0.2.0"] +aloha = ["gym-aloha>=0.1.2,<0.2.0"] pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead -xarm = ["gym-xarm>=0.1.1,<0.2.0"] -libero = ["lerobot[transformers-dep]", "libero @ git+https://github.com/huggingface/lerobot-libero.git@main#egg=libero"] - +libero = ["lerobot[transformers-dep]", "libero @ git+https://github.com/huggingface/lerobot-libero.git@upgrade-dep#egg=libero"] +metaworld = ["metaworld>=3.0.0"] # All all = [ @@ -157,9 +156,9 @@ all = [ "lerobot[video_benchmark]", "lerobot[aloha]", "lerobot[pusht]", - "lerobot[xarm]", "lerobot[phone]", "lerobot[libero]", + "lerobot[metaworld]", ] [project.scripts] diff --git a/src/lerobot/__init__.py b/src/lerobot/__init__.py index 9d3ed1893..eec574296 100644 --- a/src/lerobot/__init__.py +++ b/src/lerobot/__init__.py @@ -57,7 +57,6 @@ available_tasks_per_env = { "AlohaTransferCube-v0", ], "pusht": ["PushT-v0"], - "xarm": ["XarmLift-v0"], } available_envs = list(available_tasks_per_env.keys()) @@ -75,16 +74,6 @@ available_datasets_per_env = { # TODO(alexander-soare): Add "lerobot/pusht_keypoints". Right now we can't because this is too tightly # coupled with tests. "pusht": ["lerobot/pusht", "lerobot/pusht_image"], - "xarm": [ - "lerobot/xarm_lift_medium", - "lerobot/xarm_lift_medium_replay", - "lerobot/xarm_push_medium", - "lerobot/xarm_push_medium_replay", - "lerobot/xarm_lift_medium_image", - "lerobot/xarm_lift_medium_replay_image", - "lerobot/xarm_push_medium_image", - "lerobot/xarm_push_medium_replay_image", - ], } available_real_world_datasets = [ @@ -195,7 +184,6 @@ available_motors = [ available_policies_per_env = { "aloha": ["act"], "pusht": ["diffusion", "vqbet"], - "xarm": ["tdmpc"], "koch_real": ["act_koch_real"], "aloha_real": ["act_aloha_real"], } diff --git a/src/lerobot/envs/__init__.py b/src/lerobot/envs/__init__.py index 4977d11d9..d767b6e8c 100644 --- a/src/lerobot/envs/__init__.py +++ b/src/lerobot/envs/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv # noqa: F401 +from .configs import AlohaEnv, EnvConfig, PushtEnv # noqa: F401 diff --git a/src/lerobot/envs/configs.py b/src/lerobot/envs/configs.py index 7a979b864..3aa155093 100644 --- a/src/lerobot/envs/configs.py +++ b/src/lerobot/envs/configs.py @@ -133,45 +133,6 @@ class PushtEnv(EnvConfig): } -@EnvConfig.register_subclass("xarm") -@dataclass -class XarmEnv(EnvConfig): - task: str | None = "XarmLift-v0" - fps: int = 15 - episode_length: int = 200 - obs_type: str = "pixels_agent_pos" - render_mode: str = "rgb_array" - visualization_width: int = 384 - visualization_height: int = 384 - features: dict[str, PolicyFeature] = field( - default_factory=lambda: { - ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(4,)), - "pixels": PolicyFeature(type=FeatureType.VISUAL, shape=(84, 84, 3)), - } - ) - features_map: dict[str, str] = field( - default_factory=lambda: { - ACTION: ACTION, - "agent_pos": OBS_STATE, - "pixels": OBS_IMAGE, - } - ) - - def __post_init__(self): - if self.obs_type == "pixels_agent_pos": - self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,)) - - @property - def gym_kwargs(self) -> dict: - return { - "obs_type": self.obs_type, - "render_mode": self.render_mode, - "visualization_width": self.visualization_width, - "visualization_height": self.visualization_height, - "max_episode_steps": self.episode_length, - } - - @dataclass class ImagePreprocessingConfig: crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None @@ -306,3 +267,45 @@ class LiberoEnv(EnvConfig): "obs_type": self.obs_type, "render_mode": self.render_mode, } + + +@EnvConfig.register_subclass("metaworld") +@dataclass +class MetaworldEnv(EnvConfig): + task: str = "metaworld-push-v2" # add all tasks + fps: int = 80 + episode_length: int = 400 + obs_type: str = "pixels_agent_pos" + render_mode: str = "rgb_array" + multitask_eval: bool = True + features: dict[str, PolicyFeature] = field( + default_factory=lambda: { + "action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)), + } + ) + features_map: dict[str, str] = field( + default_factory=lambda: { + "action": ACTION, + "agent_pos": OBS_STATE, + "top": f"{OBS_IMAGE}", + "pixels/top": f"{OBS_IMAGE}", + } + ) + + def __post_init__(self): + if self.obs_type == "pixels": + self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 480, 3)) + + elif self.obs_type == "pixels_agent_pos": + self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,)) + self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 480, 3)) + + else: + raise ValueError(f"Unsupported obs_type: {self.obs_type}") + + @property + def gym_kwargs(self) -> dict: + return { + "obs_type": self.obs_type, + "render_mode": self.render_mode, + } diff --git a/src/lerobot/envs/factory.py b/src/lerobot/envs/factory.py index c27f01b65..059e0e11a 100644 --- a/src/lerobot/envs/factory.py +++ b/src/lerobot/envs/factory.py @@ -17,7 +17,7 @@ import importlib import gymnasium as gym -from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv, XarmEnv +from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv def make_env_config(env_type: str, **kwargs) -> EnvConfig: @@ -25,8 +25,6 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig: return AlohaEnv(**kwargs) elif env_type == "pusht": return PushtEnv(**kwargs) - elif env_type == "xarm": - return XarmEnv(**kwargs) elif env_type == "libero": return LiberoEnv(**kwargs) else: 
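A minimal sketch of how the Meta-World branch added to `make_env` below is meant to be consumed: `create_metaworld_envs` returns a nested `{task_group: {task_id: vec_env}}` mapping, one vectorized env per task. Passing `gym.vector.SyncVectorEnv` as `env_cls` here is an illustrative choice; inside LeRobot the vector-env class is selected by `make_env` itself.

```python
import gymnasium as gym

from lerobot.envs.metaworld import create_metaworld_envs

# One vectorized env per task in the "medium" difficulty group,
# each wrapping two copies of that task (n_envs rollouts per task).
envs = create_metaworld_envs(
    task="medium",
    n_envs=2,
    gym_kwargs={"obs_type": "pixels_agent_pos"},
    env_cls=gym.vector.SyncVectorEnv,
)
vec_env = envs["medium"][0]  # first "medium" task as a SyncVectorEnv
obs, info = vec_env.reset(seed=0)
```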
@@ -74,7 +72,18 @@ def make_env( gym_kwargs=cfg.gym_kwargs, env_cls=env_cls, ) + elif "metaworld" in cfg.type: + from lerobot.envs.metaworld import create_metaworld_envs + if cfg.task is None: + raise ValueError("MetaWorld requires a task to be specified") + + return create_metaworld_envs( + task=cfg.task, + n_envs=n_envs, + gym_kwargs=cfg.gym_kwargs, + env_cls=env_cls, + ) package_name = f"gym_{cfg.type}" try: importlib.import_module(package_name) @@ -87,7 +96,7 @@ def make_env( def _make_one(): return gym.make(gym_handle, disable_env_checker=cfg.disable_env_checker, **(cfg.gym_kwargs or {})) - vec = env_cls([_make_one for _ in range(n_envs)]) + vec = env_cls([_make_one for _ in range(n_envs)], autoreset_mode=gym.vector.AutoresetMode.SAME_STEP) # normalize to {suite: {task_id: vec_env}} for consistency suite_name = cfg.type # e.g., "pusht", "aloha" diff --git a/src/lerobot/envs/libero.py b/src/lerobot/envs/libero.py index 99ec6712f..94b08e991 100644 --- a/src/lerobot/envs/libero.py +++ b/src/lerobot/envs/libero.py @@ -260,19 +260,23 @@ class LiberoEnv(gym.Env): is_success = self._env.check_success() terminated = done or is_success - info["is_success"] = is_success - + info.update( + { + "task": self.task, + "task_id": self.task_id, + "done": done, + "is_success": is_success, + } + ) observation = self._format_raw_obs(raw_obs) - if done: + if terminated: + info["final_info"] = { + "task": self.task, + "task_id": self.task_id, + "done": bool(done), + "is_success": bool(is_success), + } self.reset() - info.update( - { - "task": self.task, - "task_id": self.task_id, - "done": done, - "is_success": is_success, - } - ) truncated = False return observation, reward, terminated, truncated, info diff --git a/src/lerobot/envs/metaworld.py b/src/lerobot/envs/metaworld.py new file mode 100644 index 000000000..9190f33ad --- /dev/null +++ b/src/lerobot/envs/metaworld.py @@ -0,0 +1,313 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +from collections import defaultdict +from collections.abc import Callable, Sequence +from pathlib import Path +from typing import Any + +import gymnasium as gym +import metaworld +import metaworld.policies as policies +import numpy as np +from gymnasium import spaces + +# ---- Load configuration data from the external JSON file ---- +CONFIG_PATH = Path(__file__).parent / "metaworld_config.json" +try: + with open(CONFIG_PATH) as f: + data = json.load(f) +except FileNotFoundError as err: + raise FileNotFoundError( + "Could not find 'metaworld_config.json'. " + "Please ensure the configuration file is in the same directory as the script." + ) from err +except json.JSONDecodeError as err: + raise ValueError( + "Failed to decode 'metaworld_config.json'. Please ensure it is a valid JSON file." 
+ ) from err + +# ---- Process the loaded data ---- + +# extract and type-check top-level dicts +task_descriptions_obj = data.get("TASK_DESCRIPTIONS") +if not isinstance(task_descriptions_obj, dict): + raise TypeError("Expected TASK_DESCRIPTIONS to be a dict[str, str]") +TASK_DESCRIPTIONS: dict[str, str] = task_descriptions_obj + +task_name_to_id_obj = data.get("TASK_NAME_TO_ID") +if not isinstance(task_name_to_id_obj, dict): + raise TypeError("Expected TASK_NAME_TO_ID to be a dict[str, int]") +TASK_NAME_TO_ID: dict[str, int] = task_name_to_id_obj + +# difficulty -> tasks mapping +difficulty_to_tasks = data.get("DIFFICULTY_TO_TASKS") +if not isinstance(difficulty_to_tasks, dict): + raise TypeError("Expected 'DIFFICULTY_TO_TASKS' to be a dict[str, list[str]]") +DIFFICULTY_TO_TASKS: dict[str, list[str]] = difficulty_to_tasks + +# convert policy strings -> actual policy classes +task_policy_mapping = data.get("TASK_POLICY_MAPPING") +if not isinstance(task_policy_mapping, dict): + raise TypeError("Expected 'TASK_POLICY_MAPPING' to be a dict[str, str]") +TASK_POLICY_MAPPING: dict[str, Any] = { + task_name: getattr(policies, policy_class_name) + for task_name, policy_class_name in task_policy_mapping.items() +} +ACTION_DIM = 4 +OBS_DIM = 4 + + +class MetaworldEnv(gym.Env): + metadata = {"render_modes": ["rgb_array"], "render_fps": 80} + + def __init__( + self, + task, + camera_name="corner2", + obs_type="pixels", + render_mode="rgb_array", + observation_width=480, + observation_height=480, + visualization_width=640, + visualization_height=480, + ): + super().__init__() + self.task = task.replace("metaworld-", "") + self.obs_type = obs_type + self.render_mode = render_mode + self.observation_width = observation_width + self.observation_height = observation_height + self.visualization_width = visualization_width + self.visualization_height = visualization_height + self.camera_name = camera_name + + self._env = self._make_envs_task(self.task) + self._max_episode_steps = self._env.max_path_length + self.task_description = TASK_DESCRIPTIONS[self.task] + + self.expert_policy = TASK_POLICY_MAPPING[self.task]() + + if self.obs_type == "state": + raise NotImplementedError() + elif self.obs_type == "pixels": + self.observation_space = spaces.Dict( + { + "pixels": spaces.Box( + low=0, + high=255, + shape=(self.observation_height, self.observation_width, 3), + dtype=np.uint8, + ) + } + ) + elif self.obs_type == "pixels_agent_pos": + self.observation_space = spaces.Dict( + { + "pixels": spaces.Box( + low=0, + high=255, + shape=(self.observation_height, self.observation_width, 3), + dtype=np.uint8, + ), + "agent_pos": spaces.Box( + low=-1000.0, + high=1000.0, + shape=(OBS_DIM,), + dtype=np.float64, + ), + } + ) + + self.action_space = spaces.Box(low=-1, high=1, shape=(ACTION_DIM,), dtype=np.float32) + + def render(self) -> np.ndarray: + """ + Render the current environment frame. + + Returns: + np.ndarray: The rendered RGB image from the environment. 
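+
+        Note:
+            Frames from the default "corner2" camera are flipped along both
+            axes before being returned, to correct the camera's orientation.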
+ """ + image = self._env.render() + if self.camera_name == "corner2": + # Images from this camera are flipped — correct them + image = np.flip(image, (0, 1)) + return image + + def _make_envs_task(self, env_name: str): + mt1 = metaworld.MT1(env_name, seed=42) + env = mt1.train_classes[env_name](render_mode="rgb_array", camera_name=self.camera_name) + env.set_task(mt1.train_tasks[0]) + if self.camera_name == "corner2": + env.model.cam_pos[2] = [ + 0.75, + 0.075, + 0.7, + ] # corner2 position, similar to https://arxiv.org/pdf/2206.14244 + env.reset() + env._freeze_rand_vec = False # otherwise no randomization + return env + + def _format_raw_obs(self, raw_obs: np.ndarray) -> dict[str, Any]: + image = None + if self._env is not None: + image = self._env.render() + if self.camera_name == "corner2": + # NOTE: The "corner2" camera in MetaWorld environments outputs images with both axes inverted. + image = np.flip(image, (0, 1)) + agent_pos = raw_obs[:4] + if self.obs_type == "state": + raise NotImplementedError( + "'state' obs_type not implemented for MetaWorld. Use pixel modes instead." + ) + + elif self.obs_type in ("pixels", "pixels_agent_pos"): + assert image is not None, ( + "Expected `image` to be rendered before constructing pixel-based observations. " + "This likely means `env.render()` returned None or the environment was not provided." + ) + + if self.obs_type == "pixels": + obs = {"pixels": image.copy()} + + else: # pixels_agent_pos + obs = { + "pixels": image.copy(), + "agent_pos": agent_pos, + } + else: + raise ValueError(f"Unknown obs_type: {self.obs_type}") + return obs + + def reset( + self, + seed: int | None = None, + **kwargs, + ) -> tuple[dict[str, Any], dict[str, Any]]: + """ + Reset the environment to its initial state. + + Args: + seed (Optional[int]): Random seed for environment initialization. + + Returns: + observation (Dict[str, Any]): The initial formatted observation. + info (Dict[str, Any]): Additional info about the reset state. + """ + super().reset(seed=seed) + + raw_obs, info = self._env.reset(seed=seed) + + observation = self._format_raw_obs(raw_obs) + + info = {"is_success": False} + return observation, info + + def step(self, action: np.ndarray) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]: + """ + Perform one environment step. + + Args: + action (np.ndarray): The action to execute, must be 1-D with shape (action_dim,). + + Returns: + observation (Dict[str, Any]): The formatted observation after the step. + reward (float): The scalar reward for this step. + terminated (bool): Whether the episode terminated successfully. + truncated (bool): Whether the episode was truncated due to a time limit. + info (Dict[str, Any]): Additional environment info. 
+ """ + if action.ndim != 1: + raise ValueError( + f"Expected action to be 1-D (shape (action_dim,)), " + f"but got shape {action.shape} with ndim={action.ndim}" + ) + raw_obs, reward, done, truncated, info = self._env.step(action) + + # Determine whether the task was successful + is_success = bool(info.get("success", 0)) + terminated = done or is_success + info.update( + { + "task": self.task, + "done": done, + "is_success": is_success, + } + ) + + # Format the raw observation into the expected structure + observation = self._format_raw_obs(raw_obs) + if terminated: + info["final_info"] = { + "task": self.task, + "done": bool(done), + "is_success": bool(is_success), + } + self.reset() + + return observation, reward, terminated, truncated, info + + def close(self): + self._env.close() + + +# ---- Main API ---------------------------------------------------------------- + + +def create_metaworld_envs( + task: str, + n_envs: int, + gym_kwargs: dict[str, Any] | None = None, + env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None, +) -> dict[str, dict[int, Any]]: + """ + Create vectorized Meta-World environments with a consistent return shape. + + Returns: + dict[task_group][task_id] -> vec_env (env_cls([...]) with exactly n_envs factories) + Notes: + - n_envs is the number of rollouts *per task* (episode_index = 0..n_envs-1). + - `task` can be a single difficulty group (e.g., "easy", "medium", "hard") or a comma-separated list. + - If a task name is not in DIFFICULTY_TO_TASKS, we treat it as a single custom task. + """ + if env_cls is None or not callable(env_cls): + raise ValueError("env_cls must be a callable that wraps a list of environment factory callables.") + if not isinstance(n_envs, int) or n_envs <= 0: + raise ValueError(f"n_envs must be a positive int; got {n_envs}.") + + gym_kwargs = dict(gym_kwargs or {}) + task_groups = [t.strip() for t in task.split(",") if t.strip()] + if not task_groups: + raise ValueError("`task` must contain at least one Meta-World task or difficulty group.") + + print(f"Creating Meta-World envs | task_groups={task_groups} | n_envs(per task)={n_envs}") + + out: dict[str, dict[int, Any]] = defaultdict(dict) + + for group in task_groups: + # if not in difficulty presets, treat it as a single custom task + tasks = DIFFICULTY_TO_TASKS.get(group, [group]) + + for tid, task_name in enumerate(tasks): + print(f"Building vec env | group={group} | task_id={tid} | task={task_name}") + + # build n_envs factories + fns = [(lambda tn=task_name: MetaworldEnv(task=tn, **gym_kwargs)) for _ in range(n_envs)] + + out[group][tid] = env_cls(fns) + + # return a plain dict for consistency + return {group: dict(task_map) for group, task_map in out.items()} diff --git a/src/lerobot/envs/metaworld_config.json b/src/lerobot/envs/metaworld_config.json new file mode 100644 index 000000000..41a417fef --- /dev/null +++ b/src/lerobot/envs/metaworld_config.json @@ -0,0 +1,121 @@ +{ + "TASK_DESCRIPTIONS": { + "assembly-v3": "Pick up a nut and place it onto a peg", + "basketball-v3": "Dunk the basketball into the basket", + "bin-picking-v3": "Grasp the puck from one bin and place it into another bin", + "box-close-v3": "Grasp the cover and close the box with it", + "button-press-topdown-v3": "Press a button from the top", + "button-press-topdown-wall-v3": "Bypass a wall and press a button from the top", + "button-press-v3": "Press a button", + "button-press-wall-v3": "Bypass a wall and press a button", + "coffee-button-v3": "Push a button on the coffee machine", + 
"coffee-pull-v3": "Pull a mug from a coffee machine", + "coffee-push-v3": "Push a mug under a coffee machine", + "dial-turn-v3": "Rotate a dial 180 degrees", + "disassemble-v3": "Pick a nut out of a peg", + "door-close-v3": "Close a door with a revolving joint", + "door-lock-v3": "Lock the door by rotating the lock clockwise", + "door-open-v3": "Open a door with a revolving joint", + "door-unlock-v3": "Unlock the door by rotating the lock counter-clockwise", + "hand-insert-v3": "Insert the gripper into a hole", + "drawer-close-v3": "Push and close a drawer", + "drawer-open-v3": "Open a drawer", + "faucet-open-v3": "Rotate the faucet counter-clockwise", + "faucet-close-v3": "Rotate the faucet clockwise", + "hammer-v3": "Hammer a screw on the wall", + "handle-press-side-v3": "Press a handle down sideways", + "handle-press-v3": "Press a handle down", + "handle-pull-side-v3": "Pull a handle up sideways", + "handle-pull-v3": "Pull a handle up", + "lever-pull-v3": "Pull a lever down 90 degrees", + "peg-insert-side-v3": "Insert a peg sideways", + "pick-place-wall-v3": "Pick a puck, bypass a wall and place the puck", + "pick-out-of-hole-v3": "Pick up a puck from a hole", + "reach-v3": "Reach a goal position", + "push-back-v3": "Push the puck to a goal", + "push-v3": "Push the puck to a goal", + "pick-place-v3": "Pick and place a puck to a goal", + "plate-slide-v3": "Slide a plate into a cabinet", + "plate-slide-side-v3": "Slide a plate into a cabinet sideways", + "plate-slide-back-v3": "Get a plate from the cabinet", + "plate-slide-back-side-v3": "Get a plate from the cabinet sideways", + "peg-unplug-side-v3": "Unplug a peg sideways", + "soccer-v3": "Kick a soccer into the goal", + "stick-push-v3": "Grasp a stick and push a box using the stick", + "stick-pull-v3": "Grasp a stick and pull a box with the stick", + "push-wall-v3": "Bypass a wall and push a puck to a goal", + "reach-wall-v3": "Bypass a wall and reach a goal", + "shelf-place-v3": "Pick and place a puck onto a shelf", + "sweep-into-v3": "Sweep a puck into a hole", + "sweep-v3": "Sweep a puck off the table", + "window-open-v3": "Push and open a window", + "window-close-v3": "Push and close a window" + }, + "TASK_NAME_TO_ID": { + "assembly-v3": 0, "basketball-v3": 1, "bin-picking-v3": 2, "box-close-v3": 3, + "button-press-topdown-v3": 4, "button-press-topdown-wall-v3": 5, "button-press-v3": 6, + "button-press-wall-v3": 7, "coffee-button-v3": 8, "coffee-pull-v3": 9, "coffee-push-v3": 10, + "dial-turn-v3": 11, "disassemble-v3": 12, "door-close-v3": 13, "door-lock-v3": 14, + "door-open-v3": 15, "door-unlock-v3": 16, "drawer-close-v3": 17, "drawer-open-v3": 18, + "faucet-close-v3": 19, "faucet-open-v3": 20, "hammer-v3": 21, "hand-insert-v3": 22, + "handle-press-side-v3": 23, "handle-press-v3": 24, "handle-pull-side-v3": 25, + "handle-pull-v3": 26, "lever-pull-v3": 27, "peg-insert-side-v3": 28, "peg-unplug-side-v3": 29, + "pick-out-of-hole-v3": 30, "pick-place-v3": 31, "pick-place-wall-v3": 32, + "plate-slide-back-side-v3": 33, "plate-slide-back-v3": 34, "plate-slide-side-v3": 35, + "plate-slide-v3": 36, "push-back-v3": 37, "push-v3": 38, "push-wall-v3": 39, "reach-v3": 40, + "reach-wall-v3": 41, "shelf-place-v3": 42, "soccer-v3": 43, "stick-pull-v3": 44, + "stick-push-v3": 45, "sweep-into-v3": 46, "sweep-v3": 47, "window-open-v3": 48, + "window-close-v3": 49 + }, + "DIFFICULTY_TO_TASKS": { + "easy": [ + "button-press-v3", "button-press-topdown-v3", "button-press-topdown-wall-v3", + "button-press-wall-v3", "coffee-button-v3", "dial-turn-v3", 
"door-close-v3", + "door-lock-v3", "door-open-v3", "door-unlock-v3", "drawer-close-v3", "drawer-open-v3", + "faucet-close-v3", "faucet-open-v3", "handle-press-v3", "handle-press-side-v3", + "handle-pull-v3", "handle-pull-side-v3", "lever-pull-v3", "plate-slide-v3", + "plate-slide-back-v3", "plate-slide-back-side-v3", "plate-slide-side-v3", "reach-v3", + "reach-wall-v3", "window-close-v3", "window-open-v3", "peg-unplug-side-v3" + ], + "medium": [ + "basketball-v3", "bin-picking-v3", "box-close-v3", "coffee-pull-v3", "coffee-push-v3", + "hammer-v3", "peg-insert-side-v3", "push-wall-v3", "soccer-v3", "sweep-v3", "sweep-into-v3" + ], + "hard": [ + "assembly-v3", "hand-insert-v3", "pick-out-of-hole-v3", "pick-place-v3", "push-v3", "push-back-v3" + ], + "very_hard": [ + "shelf-place-v3", "disassemble-v3", "stick-pull-v3", "stick-push-v3", "pick-place-wall-v3" + ] + }, + "TASK_POLICY_MAPPING": { + "assembly-v3": "SawyerAssemblyV3Policy", "basketball-v3": "SawyerBasketballV3Policy", + "bin-picking-v3": "SawyerBinPickingV3Policy", "box-close-v3": "SawyerBoxCloseV3Policy", + "button-press-topdown-v3": "SawyerButtonPressTopdownV3Policy", + "button-press-topdown-wall-v3": "SawyerButtonPressTopdownWallV3Policy", + "button-press-v3": "SawyerButtonPressV3Policy", "button-press-wall-v3": "SawyerButtonPressWallV3Policy", + "coffee-button-v3": "SawyerCoffeeButtonV3Policy", "coffee-pull-v3": "SawyerCoffeePullV3Policy", + "coffee-push-v3": "SawyerCoffeePushV3Policy", "dial-turn-v3": "SawyerDialTurnV3Policy", + "disassemble-v3": "SawyerDisassembleV3Policy", "door-close-v3": "SawyerDoorCloseV3Policy", + "door-lock-v3": "SawyerDoorLockV3Policy", "door-open-v3": "SawyerDoorOpenV3Policy", + "door-unlock-v3": "SawyerDoorUnlockV3Policy", "drawer-close-v3": "SawyerDrawerCloseV3Policy", + "drawer-open-v3": "SawyerDrawerOpenV3Policy", "faucet-close-v3": "SawyerFaucetCloseV3Policy", + "faucet-open-v3": "SawyerFaucetOpenV3Policy", "hammer-v3": "SawyerHammerV3Policy", + "hand-insert-v3": "SawyerHandInsertV3Policy", "handle-press-side-v3": "SawyerHandlePressSideV3Policy", + "handle-press-v3": "SawyerHandlePressV3Policy", "handle-pull-side-v3": "SawyerHandlePullSideV3Policy", + "handle-pull-v3": "SawyerHandlePullV3Policy", "lever-pull-v3": "SawyerLeverPullV3Policy", + "peg-insert-side-v3": "SawyerPegInsertionSideV3Policy", "peg-unplug-side-v3": "SawyerPegUnplugSideV3Policy", + "pick-out-of-hole-v3": "SawyerPickOutOfHoleV3Policy", "pick-place-v3": "SawyerPickPlaceV3Policy", + "pick-place-wall-v3": "SawyerPickPlaceWallV3Policy", + "plate-slide-back-side-v3": "SawyerPlateSlideBackSideV3Policy", + "plate-slide-back-v3": "SawyerPlateSlideBackV3Policy", + "plate-slide-side-v3": "SawyerPlateSlideSideV3Policy", "plate-slide-v3": "SawyerPlateSlideV3Policy", + "push-back-v3": "SawyerPushBackV3Policy", "push-v3": "SawyerPushV3Policy", + "push-wall-v3": "SawyerPushWallV3Policy", "reach-v3": "SawyerReachV3Policy", + "reach-wall-v3": "SawyerReachWallV3Policy", "shelf-place-v3": "SawyerShelfPlaceV3Policy", + "soccer-v3": "SawyerSoccerV3Policy", "stick-pull-v3": "SawyerStickPullV3Policy", + "stick-push-v3": "SawyerStickPushV3Policy", "sweep-into-v3": "SawyerSweepIntoV3Policy", + "sweep-v3": "SawyerSweepV3Policy", "window-open-v3": "SawyerWindowOpenV3Policy", + "window-close-v3": "SawyerWindowCloseV3Policy" + } +} diff --git a/src/lerobot/scripts/lerobot_eval.py b/src/lerobot/scripts/lerobot_eval.py index d45be5c42..aed7d32e3 100644 --- a/src/lerobot/scripts/lerobot_eval.py +++ b/src/lerobot/scripts/lerobot_eval.py @@ -180,9 +180,15 @@ def 
rollout( render_callback(env) # VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't - # available of none of the envs finished. + # available if none of the envs finished. if "final_info" in info: - successes = [info["is_success"] if info is not None else False for info in info["final_info"]] + final_info = info["final_info"] + if not isinstance(final_info, dict): + raise RuntimeError( + "Unsupported `final_info` format: expected dict (Gymnasium >= 1.0). " + "You're likely using an older version of gymnasium (< 1.0). Please upgrade." + ) + successes = final_info["is_success"].tolist() else: successes = [False] * env.num_envs diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py index 34fa89390..345526d90 100644 --- a/tests/policies/test_policies.py +++ b/tests/policies/test_policies.py @@ -95,7 +95,6 @@ def test_get_policy_and_config_classes(policy_name: str): @pytest.mark.parametrize( "ds_repo_id,env_name,env_kwargs,policy_name,policy_kwargs", [ - ("lerobot/xarm_lift_medium", "xarm", {}, "tdmpc", {"use_mpc": True}), ("lerobot/pusht", "pusht", {}, "diffusion", {}), ("lerobot/pusht", "pusht", {}, "vqbet", {}), ("lerobot/pusht", "pusht", {}, "act", {}), @@ -328,8 +327,6 @@ def test_multikey_construction(multikey: bool): # TODO(alexander-soare): `policy.use_mpc=false` was previously the default in the config yaml but it # was changed to true. For some reason, tests would pass locally, but not in CI. So here we override # to test with `policy.use_mpc=false`. - ("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": False}, "use_policy"), - # ("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": True}, "use_mpc"), # TODO(rcadene): the diffusion model was normalizing the image in mean=0.5 std=0.5 which is a hack supposed to # to normalize the image at all. In our current codebase we dont normalize at all. But there is still a minor difference # that fails the test. However, by testing to normalize the image with 0.5 0.5 in the current codebase, the test pass.