forked from tangger/lerobot
feat(sim): add metaworld env (#2088)
* add metaworld
* smol update
* update design
* Update src/lerobot/envs/metaworld.py
* update
* small changes
* iterate on review
* small fix
* small fix
* add docs
* update doc
* add better gif
* smol doc fix
* upgrade gymnasium
* add note
* deprecate gym-xarm
* more changes
* update doc
* comply with mypy
* more fixes
* update readme
* precommit
* update pusht
* add pusht instead
* changes
* style
* add changes
* update
* revert
* update v2
* chore(envs): move metaworld config to its own file + remove comments + simplify _format_raw_obs (#2200)
* update final changes

Signed-off-by: Jade Choghari <chogharijade@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
```diff
@@ -72,7 +72,6 @@ post it.
 Look at our implementations for [datasets](./src/lerobot/datasets/), [policies](./src/lerobot/policies/),
 environments ([aloha](https://github.com/huggingface/gym-aloha),
-[xarm](https://github.com/huggingface/gym-xarm),
 [pusht](https://github.com/huggingface/gym-pusht))
 and follow the same api design.
```
Makefile (10 changed lines)
```diff
@@ -119,10 +119,9 @@ test-tdmpc-ete-train:
 		--policy.type=tdmpc \
 		--policy.device=$(DEVICE) \
 		--policy.push_to_hub=false \
-		--env.type=xarm \
-		--env.task=XarmLift-v0 \
+		--env.type=pusht \
 		--env.episode_length=5 \
-		--dataset.repo_id=lerobot/xarm_lift_medium \
+		--dataset.repo_id=lerobot/pusht_image \
 		--dataset.image_transforms.enable=true \
 		--dataset.episodes="[0]" \
 		--batch_size=2 \
@@ -140,9 +139,10 @@ test-tdmpc-ete-eval:
 	lerobot-eval \
 		--policy.path=tests/outputs/tdmpc/checkpoints/000002/pretrained_model \
 		--policy.device=$(DEVICE) \
-		--env.type=xarm \
+		--env.type=pusht \
 		--env.episode_length=5 \
-		--env.task=XarmLift-v0 \
+		--env.observation_height=96 \
+		--env.observation_width=96 \
 		--eval.n_episodes=1 \
 		--eval.batch_size=1
```
```diff
@@ -7,8 +7,6 @@
 - sections:
   - local: il_robots
     title: Imitation Learning for Robots
-  - local: il_sim
-    title: Imitation Learning in Sim
   - local: cameras
     title: Cameras
   - local: integrate_hardware
@@ -37,9 +35,15 @@
     title: π₀ (Pi0)
   - local: pi05
     title: π₀.₅ (Pi05)
+  title: "Policies"
+- sections:
+  - local: il_sim
+    title: Imitation Learning in Sim
   - local: libero
     title: Using Libero
-  title: "Policies"
+  - local: metaworld
+    title: Using MetaWorld
+  title: "Simulation"
 - sections:
   - local: introduction_processors
     title: Introduction to Robot Processors
```
````diff
@@ -91,7 +91,7 @@ LeRobot provides optional extras for specific functionalities. Multiple extras c
 
 ### Simulations
 
-Install environment packages: `aloha` ([gym-aloha](https://github.com/huggingface/gym-aloha)), `xarm` ([gym-xarm](https://github.com/huggingface/gym-xarm)), or `pusht` ([gym-pusht](https://github.com/huggingface/gym-pusht))
+Install environment packages: `aloha` ([gym-aloha](https://github.com/huggingface/gym-aloha)), or `pusht` ([gym-pusht](https://github.com/huggingface/gym-pusht))
 Example:
 
 ```bash
````
````diff
@@ -137,7 +137,7 @@ The finetuned model can be found here:
 We then evaluate the finetuned model using the LeRobot LIBERO implementation, by running the following command:
 
 ```bash
-python src/lerobot/scripts/eval.py \
+lerobot-eval \
     --output_dir=/logs/ \
     --env.type=libero \
     --env.task=libero_spatial,libero_object,libero_goal,libero_10 \
````
docs/source/metaworld.mdx (new file, 80 lines)

# Meta-World

Meta-World is an open-source simulation benchmark for multi-task and meta reinforcement learning in continuous-control robotic manipulation. It gives researchers a shared, realistic playground to test whether algorithms can _learn many different tasks_ and _generalize quickly to new ones_ — two central challenges for real-world robotics.

- 📄 [MetaWorld paper](https://arxiv.org/pdf/1910.10897)
- 💻 [Original MetaWorld repo](https://github.com/Farama-Foundation/Metaworld)



## Why Meta-World matters

- **Diverse, realistic tasks.** Meta-World bundles a large suite of simulated manipulation tasks (50 in the MT50 suite) built around everyday objects and a common tabletop Sawyer arm. This diversity exposes algorithms to a wide variety of dynamics, contacts, and goal specifications while keeping a consistent control and observation structure.
- **Focus on generalization and multi-task learning.** By evaluating across task distributions that share structure but differ in goals and objects, Meta-World reveals whether an agent truly learns transferable skills rather than overfitting to a narrow task.
- **Standardized evaluation protocol.** It provides clear evaluation modes and difficulty splits, so different methods can be compared fairly across easy, medium, hard, and very-hard regimes.
- **Empirical insight.** Past evaluations on Meta-World show impressive progress on some fronts, but also highlight that current multi-task and meta-RL methods still struggle with large, diverse task sets. That gap points to important research directions.

## What it enables in LeRobot

In LeRobot, you can evaluate any policy or vision-language-action (VLA) model on Meta-World tasks and get a clear success-rate measure. The integration is designed to be straightforward:

- We provide a LeRobot-ready dataset for Meta-World (MT50) on the HF Hub: `https://huggingface.co/datasets/lerobot/metaworld_mt50`.
- The dataset is formatted for the MT50 evaluation that uses all 50 tasks (the most challenging multi-task setting).
- MT50 gives the policy a one-hot task vector and uses fixed object/goal positions for consistency.
- Task descriptions and the exact keys required for evaluation are available in the repo/dataset — use these to ensure your policy outputs the right success signals.
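The one-hot task conditioning mentioned above is straightforward to reproduce; here is a minimal stand-alone sketch (the helper name and list-based layout are illustrative, not the LeRobot API):

```python
# Illustrative sketch of MT50-style one-hot task conditioning: the policy
# receives a 50-dimensional vector with a single 1.0 at the task's index.
# `one_hot_task` is a hypothetical helper, not part of LeRobot.
def one_hot_task(task_id: int, n_tasks: int = 50) -> list[float]:
    if not 0 <= task_id < n_tasks:
        raise ValueError(f"task_id must be in [0, {n_tasks}), got {task_id}")
    vec = [0.0] * n_tasks
    vec[task_id] = 1.0
    return vec
```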
## Quick start: train a SmolVLA policy on Meta-World

Example command to train a SmolVLA policy on a subset of tasks:

```bash
lerobot-train \
  --policy.type=smolvla \
  --policy.repo_id=${HF_USER}/metaworld-test \
  --policy.load_vlm_weights=true \
  --dataset.repo_id=lerobot/metaworld_mt50 \
  --env.type=metaworld \
  --env.task=assembly-v3,dial-turn-v3,handle-press-side-v3 \
  --output_dir=./outputs/ \
  --steps=100000 \
  --batch_size=4 \
  --eval.batch_size=1 \
  --eval.n_episodes=1 \
  --eval_freq=1000
```

Notes:

- `--env.task` accepts explicit task lists (comma-separated) or difficulty groups (e.g., `--env.task="hard"`).
- Adjust `batch_size`, `steps`, and `eval_freq` to match your compute budget.
- **Gymnasium assertion error:** if you encounter an error like `AssertionError: ['human', 'rgb_array', 'depth_array']` when running MetaWorld environments, it comes from a mismatch between MetaWorld and your Gymnasium version. To ensure compatibility, we recommend:

  ```bash
  pip install "gymnasium==1.1.0"
  ```

## Quick start: evaluate a trained policy

To evaluate a trained policy on the Meta-World medium difficulty split:

```bash
lerobot-eval \
  --policy.path="your-policy-id" \
  --env.type=metaworld \
  --env.task=medium \
  --eval.batch_size=1 \
  --eval.n_episodes=2
```

This runs the episodes and returns per-task success rates using the standard Meta-World evaluation keys.

## Practical tips

- If you care about generalization, run on the full MT50 suite — it's intentionally challenging and reveals strengths and weaknesses better than a few narrow tasks.
- Use the one-hot task conditioning for multi-task training (MT10/MT50 conventions) so policies have explicit task context.
- Inspect the dataset task descriptions and the `info["is_success"]` keys when writing post-processing or logging so your success metrics line up with the benchmark.
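Per-task success rates like those reported by the evaluation can be aggregated from episode-level `is_success` flags. A minimal sketch, assuming a simple list-of-dicts episode record (not the actual eval output format):

```python
from collections import defaultdict

# Hypothetical episode records: each carries the task name and the episode's
# `is_success` flag, mirroring the Meta-World evaluation keys.
def per_task_success_rates(episodes: list[dict]) -> dict[str, float]:
    counts: dict[str, list[int]] = defaultdict(lambda: [0, 0])  # task -> [successes, total]
    for ep in episodes:
        stats = counts[ep["task"]]
        stats[0] += int(ep["is_success"])
        stats[1] += 1
    return {task: successes / total for task, (successes, total) in counts.items()}
```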
```diff
@@ -80,7 +80,7 @@ dependencies = [
     "torchvision>=0.21.0,<0.23.0", # TODO: Bumb dependency
 
     "draccus==0.10.0", # TODO: Remove ==
-    "gymnasium>=0.29.1,<1.0.0", # TODO: Bumb dependency
+    "gymnasium>=1.0.0",
     "rerun-sdk>=0.21.0,<0.23.0", # TODO: Bumb dependency
 
     # Support dependencies
@@ -133,11 +133,10 @@ test = ["pytest>=8.1.0,<9.0.0", "pytest-timeout>=2.4.0,<3.0.0", "pytest-cov>=5.0
 video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
 
 # Simulation
-aloha = ["gym-aloha>=0.1.1,<0.2.0"]
+aloha = ["gym-aloha>=0.1.2,<0.2.0"]
 pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
-xarm = ["gym-xarm>=0.1.1,<0.2.0"]
-libero = ["lerobot[transformers-dep]", "libero @ git+https://github.com/huggingface/lerobot-libero.git@main#egg=libero"]
+libero = ["lerobot[transformers-dep]", "libero @ git+https://github.com/huggingface/lerobot-libero.git@upgrade-dep#egg=libero"]
+metaworld = ["metaworld>=3.0.0"]
 
 
 # All
 all = [
@@ -157,9 +156,9 @@ all = [
     "lerobot[video_benchmark]",
     "lerobot[aloha]",
     "lerobot[pusht]",
-    "lerobot[xarm]",
     "lerobot[phone]",
     "lerobot[libero]",
+    "lerobot[metaworld]",
 ]
 
 [project.scripts]
```
```diff
@@ -57,7 +57,6 @@ available_tasks_per_env = {
         "AlohaTransferCube-v0",
     ],
     "pusht": ["PushT-v0"],
-    "xarm": ["XarmLift-v0"],
 }
 available_envs = list(available_tasks_per_env.keys())
 
@@ -75,16 +74,6 @@ available_datasets_per_env = {
     # TODO(alexander-soare): Add "lerobot/pusht_keypoints". Right now we can't because this is too tightly
     # coupled with tests.
     "pusht": ["lerobot/pusht", "lerobot/pusht_image"],
-    "xarm": [
-        "lerobot/xarm_lift_medium",
-        "lerobot/xarm_lift_medium_replay",
-        "lerobot/xarm_push_medium",
-        "lerobot/xarm_push_medium_replay",
-        "lerobot/xarm_lift_medium_image",
-        "lerobot/xarm_lift_medium_replay_image",
-        "lerobot/xarm_push_medium_image",
-        "lerobot/xarm_push_medium_replay_image",
-    ],
 }
 
 available_real_world_datasets = [
@@ -195,7 +184,6 @@ available_motors = [
 available_policies_per_env = {
     "aloha": ["act"],
     "pusht": ["diffusion", "vqbet"],
-    "xarm": ["tdmpc"],
     "koch_real": ["act_koch_real"],
     "aloha_real": ["act_aloha_real"],
 }
```
```diff
@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv  # noqa: F401
+from .configs import AlohaEnv, EnvConfig, PushtEnv  # noqa: F401
@@ -133,45 +133,6 @@ class PushtEnv(EnvConfig):
     }
 
 
-@EnvConfig.register_subclass("xarm")
-@dataclass
-class XarmEnv(EnvConfig):
-    task: str | None = "XarmLift-v0"
-    fps: int = 15
-    episode_length: int = 200
-    obs_type: str = "pixels_agent_pos"
-    render_mode: str = "rgb_array"
-    visualization_width: int = 384
-    visualization_height: int = 384
-    features: dict[str, PolicyFeature] = field(
-        default_factory=lambda: {
-            ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
-            "pixels": PolicyFeature(type=FeatureType.VISUAL, shape=(84, 84, 3)),
-        }
-    )
-    features_map: dict[str, str] = field(
-        default_factory=lambda: {
-            ACTION: ACTION,
-            "agent_pos": OBS_STATE,
-            "pixels": OBS_IMAGE,
-        }
-    )
-
-    def __post_init__(self):
-        if self.obs_type == "pixels_agent_pos":
-            self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,))
-
-    @property
-    def gym_kwargs(self) -> dict:
-        return {
-            "obs_type": self.obs_type,
-            "render_mode": self.render_mode,
-            "visualization_width": self.visualization_width,
-            "visualization_height": self.visualization_height,
-            "max_episode_steps": self.episode_length,
-        }
-
-
 @dataclass
 class ImagePreprocessingConfig:
     crop_params_dict: dict[str, tuple[int, int, int, int]] | None = None
```
```diff
@@ -306,3 +267,45 @@ class LiberoEnv(EnvConfig):
             "obs_type": self.obs_type,
             "render_mode": self.render_mode,
         }
+
+
+@EnvConfig.register_subclass("metaworld")
+@dataclass
+class MetaworldEnv(EnvConfig):
+    task: str = "metaworld-push-v2"  # add all tasks
+    fps: int = 80
+    episode_length: int = 400
+    obs_type: str = "pixels_agent_pos"
+    render_mode: str = "rgb_array"
+    multitask_eval: bool = True
+    features: dict[str, PolicyFeature] = field(
+        default_factory=lambda: {
+            "action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
+        }
+    )
+    features_map: dict[str, str] = field(
+        default_factory=lambda: {
+            "action": ACTION,
+            "agent_pos": OBS_STATE,
+            "top": f"{OBS_IMAGE}",
+            "pixels/top": f"{OBS_IMAGE}",
+        }
+    )
+
+    def __post_init__(self):
+        if self.obs_type == "pixels":
+            self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 480, 3))
+        elif self.obs_type == "pixels_agent_pos":
+            self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,))
+            self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 480, 3))
+        else:
+            raise ValueError(f"Unsupported obs_type: {self.obs_type}")
+
+    @property
+    def gym_kwargs(self) -> dict:
+        return {
+            "obs_type": self.obs_type,
+            "render_mode": self.render_mode,
+        }
```
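The `__post_init__` in this hunk selects which observation features the config exposes based on `obs_type`. A simplified stand-alone sketch of the same dispatch, using plain shape tuples instead of `PolicyFeature` (names here are illustrative):

```python
# Sketch of the obs_type -> feature-shape dispatch used by MetaworldEnv's
# __post_init__: start from the action feature and add the observation
# features implied by obs_type. `metaworld_features` is a hypothetical helper.
def metaworld_features(obs_type: str) -> dict[str, tuple[int, ...]]:
    features: dict[str, tuple[int, ...]] = {"action": (4,)}
    if obs_type == "pixels":
        features["top"] = (480, 480, 3)
    elif obs_type == "pixels_agent_pos":
        features["agent_pos"] = (4,)
        features["pixels/top"] = (480, 480, 3)
    else:
        raise ValueError(f"Unsupported obs_type: {obs_type}")
    return features
```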
```diff
@@ -17,7 +17,7 @@ import importlib
 
 import gymnasium as gym
 
-from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv, XarmEnv
+from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv
 
 
 def make_env_config(env_type: str, **kwargs) -> EnvConfig:
@@ -25,8 +25,6 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
         return AlohaEnv(**kwargs)
     elif env_type == "pusht":
         return PushtEnv(**kwargs)
-    elif env_type == "xarm":
-        return XarmEnv(**kwargs)
     elif env_type == "libero":
         return LiberoEnv(**kwargs)
     else:
@@ -74,7 +72,18 @@ def make_env(
             gym_kwargs=cfg.gym_kwargs,
             env_cls=env_cls,
         )
+    elif "metaworld" in cfg.type:
+        from lerobot.envs.metaworld import create_metaworld_envs
+
+        if cfg.task is None:
+            raise ValueError("MetaWorld requires a task to be specified")
+
+        return create_metaworld_envs(
+            task=cfg.task,
+            n_envs=n_envs,
+            gym_kwargs=cfg.gym_kwargs,
+            env_cls=env_cls,
+        )
     package_name = f"gym_{cfg.type}"
     try:
         importlib.import_module(package_name)
```
```diff
@@ -87,7 +96,7 @@ def make_env(
     def _make_one():
         return gym.make(gym_handle, disable_env_checker=cfg.disable_env_checker, **(cfg.gym_kwargs or {}))
 
-    vec = env_cls([_make_one for _ in range(n_envs)])
+    vec = env_cls([_make_one for _ in range(n_envs)], autoreset_mode=gym.vector.AutoresetMode.SAME_STEP)
 
     # normalize to {suite: {task_id: vec_env}} for consistency
     suite_name = cfg.type  # e.g., "pusht", "aloha"
```
```diff
@@ -260,19 +260,23 @@ class LiberoEnv(gym.Env):
 
         is_success = self._env.check_success()
         terminated = done or is_success
-        info["is_success"] = is_success
+        info.update(
+            {
+                "task": self.task,
+                "task_id": self.task_id,
+                "done": done,
+                "is_success": is_success,
+            }
+        )
         observation = self._format_raw_obs(raw_obs)
-        if done:
+        if terminated:
+            info["final_info"] = {
+                "task": self.task,
+                "task_id": self.task_id,
+                "done": bool(done),
+                "is_success": bool(is_success),
+            }
             self.reset()
-        info.update(
-            {
-                "task": self.task,
-                "task_id": self.task_id,
-                "done": done,
-                "is_success": is_success,
-            }
-        )
         truncated = False
         return observation, reward, terminated, truncated, info
 
```
src/lerobot/envs/metaworld.py (new file, 313 lines)
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import json
|
||||||
|
from collections import defaultdict
|
||||||
|
from collections.abc import Callable, Sequence
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
import metaworld
|
||||||
|
import metaworld.policies as policies
|
||||||
|
import numpy as np
|
||||||
|
from gymnasium import spaces
|
||||||
|
|
||||||
|
# ---- Load configuration data from the external JSON file ----
|
||||||
|
CONFIG_PATH = Path(__file__).parent / "metaworld_config.json"
|
||||||
|
try:
|
||||||
|
with open(CONFIG_PATH) as f:
|
||||||
|
data = json.load(f)
|
||||||
|
except FileNotFoundError as err:
|
||||||
|
raise FileNotFoundError(
|
||||||
|
"Could not find 'metaworld_config.json'. "
|
||||||
|
"Please ensure the configuration file is in the same directory as the script."
|
||||||
|
) from err
|
||||||
|
except json.JSONDecodeError as err:
|
||||||
|
raise ValueError(
|
||||||
|
"Failed to decode 'metaworld_config.json'. Please ensure it is a valid JSON file."
|
||||||
|
) from err
|
||||||
|
|
||||||
|
# ---- Process the loaded data ----
|
||||||
|
|
||||||
|
# extract and type-check top-level dicts
|
||||||
|
task_descriptions_obj = data.get("TASK_DESCRIPTIONS")
|
||||||
|
if not isinstance(task_descriptions_obj, dict):
|
||||||
|
raise TypeError("Expected TASK_DESCRIPTIONS to be a dict[str, str]")
|
||||||
|
TASK_DESCRIPTIONS: dict[str, str] = task_descriptions_obj
|
||||||
|
|
||||||
|
task_name_to_id_obj = data.get("TASK_NAME_TO_ID")
|
||||||
|
if not isinstance(task_name_to_id_obj, dict):
|
||||||
|
raise TypeError("Expected TASK_NAME_TO_ID to be a dict[str, int]")
|
||||||
|
TASK_NAME_TO_ID: dict[str, int] = task_name_to_id_obj
|
||||||
|
|
||||||
|
# difficulty -> tasks mapping
|
||||||
|
difficulty_to_tasks = data.get("DIFFICULTY_TO_TASKS")
|
||||||
|
if not isinstance(difficulty_to_tasks, dict):
|
||||||
|
raise TypeError("Expected 'DIFFICULTY_TO_TASKS' to be a dict[str, list[str]]")
|
||||||
|
DIFFICULTY_TO_TASKS: dict[str, list[str]] = difficulty_to_tasks
|
||||||
|
|
||||||
|
# convert policy strings -> actual policy classes
|
||||||
|
task_policy_mapping = data.get("TASK_POLICY_MAPPING")
|
||||||
|
if not isinstance(task_policy_mapping, dict):
|
||||||
|
raise TypeError("Expected 'TASK_POLICY_MAPPING' to be a dict[str, str]")
|
||||||
|
TASK_POLICY_MAPPING: dict[str, Any] = {
|
||||||
|
task_name: getattr(policies, policy_class_name)
|
||||||
|
for task_name, policy_class_name in task_policy_mapping.items()
|
||||||
|
}
|
||||||
|
ACTION_DIM = 4
|
||||||
|
OBS_DIM = 4
|
||||||
|
|
||||||
|
|
||||||
|
class MetaworldEnv(gym.Env):
|
||||||
|
metadata = {"render_modes": ["rgb_array"], "render_fps": 80}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
task,
|
||||||
|
camera_name="corner2",
|
||||||
|
obs_type="pixels",
|
||||||
|
render_mode="rgb_array",
|
||||||
|
observation_width=480,
|
||||||
|
observation_height=480,
|
||||||
|
visualization_width=640,
|
||||||
|
visualization_height=480,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.task = task.replace("metaworld-", "")
|
||||||
|
self.obs_type = obs_type
|
||||||
|
self.render_mode = render_mode
|
||||||
|
self.observation_width = observation_width
|
||||||
|
self.observation_height = observation_height
|
||||||
|
self.visualization_width = visualization_width
|
||||||
|
self.visualization_height = visualization_height
|
||||||
|
self.camera_name = camera_name
|
||||||
|
|
||||||
|
self._env = self._make_envs_task(self.task)
|
||||||
|
self._max_episode_steps = self._env.max_path_length
|
||||||
|
self.task_description = TASK_DESCRIPTIONS[self.task]
|
||||||
|
|
||||||
|
self.expert_policy = TASK_POLICY_MAPPING[self.task]()
|
||||||
|
|
||||||
|
if self.obs_type == "state":
|
||||||
|
raise NotImplementedError()
|
||||||
|
elif self.obs_type == "pixels":
|
||||||
|
self.observation_space = spaces.Dict(
|
||||||
|
{
|
||||||
|
"pixels": spaces.Box(
|
||||||
|
low=0,
|
||||||
|
high=255,
|
||||||
|
shape=(self.observation_height, self.observation_width, 3),
|
||||||
|
dtype=np.uint8,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif self.obs_type == "pixels_agent_pos":
|
||||||
|
self.observation_space = spaces.Dict(
|
||||||
|
{
|
||||||
|
"pixels": spaces.Box(
|
||||||
|
low=0,
|
||||||
|
high=255,
|
||||||
|
shape=(self.observation_height, self.observation_width, 3),
|
||||||
|
dtype=np.uint8,
|
||||||
|
),
|
||||||
|
"agent_pos": spaces.Box(
|
||||||
|
low=-1000.0,
|
||||||
|
high=1000.0,
|
||||||
|
shape=(OBS_DIM,),
|
||||||
|
dtype=np.float64,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
self.action_space = spaces.Box(low=-1, high=1, shape=(ACTION_DIM,), dtype=np.float32)
|
||||||
|
|
||||||
|
def render(self) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Render the current environment frame.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray: The rendered RGB image from the environment.
|
||||||
|
"""
|
||||||
|
image = self._env.render()
|
||||||
|
if self.camera_name == "corner2":
|
||||||
|
# Images from this camera are flipped — correct them
|
||||||
|
image = np.flip(image, (0, 1))
|
||||||
|
return image
|
||||||
|
|
||||||
|
def _make_envs_task(self, env_name: str):
|
||||||
|
mt1 = metaworld.MT1(env_name, seed=42)
|
||||||
|
env = mt1.train_classes[env_name](render_mode="rgb_array", camera_name=self.camera_name)
|
||||||
|
env.set_task(mt1.train_tasks[0])
|
||||||
|
if self.camera_name == "corner2":
|
||||||
|
env.model.cam_pos[2] = [
|
||||||
|
0.75,
|
||||||
|
0.075,
|
||||||
|
0.7,
|
||||||
|
] # corner2 position, similar to https://arxiv.org/pdf/2206.14244
|
||||||
|
env.reset()
|
||||||
|
env._freeze_rand_vec = False # otherwise no randomization
|
||||||
|
return env
|
||||||
|
|
||||||
|
def _format_raw_obs(self, raw_obs: np.ndarray) -> dict[str, Any]:
|
||||||
|
image = None
|
||||||
|
if self._env is not None:
|
||||||
|
image = self._env.render()
|
||||||
|
if self.camera_name == "corner2":
|
||||||
|
# NOTE: The "corner2" camera in MetaWorld environments outputs images with both axes inverted.
|
||||||
|
image = np.flip(image, (0, 1))
|
||||||
|
agent_pos = raw_obs[:4]
|
||||||
|
if self.obs_type == "state":
|
||||||
|
raise NotImplementedError(
|
||||||
|
"'state' obs_type not implemented for MetaWorld. Use pixel modes instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
elif self.obs_type in ("pixels", "pixels_agent_pos"):
|
||||||
|
assert image is not None, (
|
||||||
|
"Expected `image` to be rendered before constructing pixel-based observations. "
|
||||||
|
"This likely means `env.render()` returned None or the environment was not provided."
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.obs_type == "pixels":
|
||||||
|
obs = {"pixels": image.copy()}
|
||||||
|
|
||||||
|
else: # pixels_agent_pos
|
||||||
|
obs = {
|
||||||
|
"pixels": image.copy(),
|
||||||
|
"agent_pos": agent_pos,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown obs_type: {self.obs_type}")
|
||||||
|
return obs
    def reset(
        self,
        seed: int | None = None,
        **kwargs,
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """
        Reset the environment to its initial state.

        Args:
            seed (Optional[int]): Random seed for environment initialization.

        Returns:
            observation (Dict[str, Any]): The initial formatted observation.
            info (Dict[str, Any]): Additional info about the reset state.
        """
        super().reset(seed=seed)

        raw_obs, info = self._env.reset(seed=seed)

        observation = self._format_raw_obs(raw_obs)

        info = {"is_success": False}
        return observation, info

    def step(self, action: np.ndarray) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
        """
        Perform one environment step.

        Args:
            action (np.ndarray): The action to execute, must be 1-D with shape (action_dim,).

        Returns:
            observation (Dict[str, Any]): The formatted observation after the step.
            reward (float): The scalar reward for this step.
            terminated (bool): Whether the episode terminated successfully.
            truncated (bool): Whether the episode was truncated due to a time limit.
            info (Dict[str, Any]): Additional environment info.
        """
        if action.ndim != 1:
            raise ValueError(
                f"Expected action to be 1-D (shape (action_dim,)), "
                f"but got shape {action.shape} with ndim={action.ndim}"
            )
        raw_obs, reward, done, truncated, info = self._env.step(action)

        # Determine whether the task was successful
        is_success = bool(info.get("success", 0))
        terminated = done or is_success
        info.update(
            {
                "task": self.task,
                "done": done,
                "is_success": is_success,
            }
        )

        # Format the raw observation into the expected structure
        observation = self._format_raw_obs(raw_obs)
        if terminated:
            info["final_info"] = {
                "task": self.task,
                "done": bool(done),
                "is_success": bool(is_success),
            }
            self.reset()

        return observation, reward, terminated, truncated, info

    def close(self):
        self._env.close()

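The success-based termination rule in `step` (terminate as soon as the wrapped env reports `info["success"]`, even if `done` is still false) can be illustrated without Meta-World installed. The `StubEnv` below is a hypothetical stand-in, not lerobot code; it only mimics the `(obs, reward, done, truncated, info)` contract:

```python
# Hypothetical stub mimicking the wrapped env's step/reset contract.
class StubEnv:
    def __init__(self):
        self.t = 0

    def reset(self, seed=None):
        self.t = 0
        return {"obs": 0.0}, {}

    def step(self, action):
        self.t += 1
        # Report success on the third step; `done` stays False throughout.
        info = {"success": 1 if self.t >= 3 else 0}
        return {"obs": float(self.t)}, 1.0, False, False, info


env = StubEnv()
obs, info = env.reset()
terminated = False
steps = 0
while not terminated:
    obs, reward, done, truncated, info = env.step(0.0)
    # Same rule as the wrapper: terminate when the env flags success.
    is_success = bool(info.get("success", 0))
    terminated = done or is_success
    steps += 1

print(steps, is_success)  # 3 True
```

The episode ends on the success flag alone, which is why the wrapper also stores a `final_info` snapshot before auto-resetting.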
# ---- Main API ----------------------------------------------------------------


def create_metaworld_envs(
    task: str,
    n_envs: int,
    gym_kwargs: dict[str, Any] | None = None,
    env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
) -> dict[str, dict[int, Any]]:
    """
    Create vectorized Meta-World environments with a consistent return shape.

    Returns:
        dict[task_group][task_id] -> vec_env (env_cls([...]) with exactly n_envs factories)

    Notes:
        - n_envs is the number of rollouts *per task* (episode_index = 0..n_envs-1).
        - `task` can be a single difficulty group (e.g., "easy", "medium", "hard") or a comma-separated list.
        - If a task name is not in DIFFICULTY_TO_TASKS, we treat it as a single custom task.
    """
    if env_cls is None or not callable(env_cls):
        raise ValueError("env_cls must be a callable that wraps a list of environment factory callables.")
    if not isinstance(n_envs, int) or n_envs <= 0:
        raise ValueError(f"n_envs must be a positive int; got {n_envs}.")

    gym_kwargs = dict(gym_kwargs or {})
    task_groups = [t.strip() for t in task.split(",") if t.strip()]
    if not task_groups:
        raise ValueError("`task` must contain at least one Meta-World task or difficulty group.")

    print(f"Creating Meta-World envs | task_groups={task_groups} | n_envs(per task)={n_envs}")

    out: dict[str, dict[int, Any]] = defaultdict(dict)

    for group in task_groups:
        # If not in the difficulty presets, treat it as a single custom task.
        tasks = DIFFICULTY_TO_TASKS.get(group, [group])

        for tid, task_name in enumerate(tasks):
            print(f"Building vec env | group={group} | task_id={tid} | task={task_name}")

            # Build n_envs factories; `tn=task_name` binds the current loop value
            # as a default argument so each factory keeps its own task name.
            fns = [(lambda tn=task_name: MetaworldEnv(task=tn, **gym_kwargs)) for _ in range(n_envs)]

            out[group][tid] = env_cls(fns)

    # Return a plain dict for consistency.
    return {group: dict(task_map) for group, task_map in out.items()}
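The factory list relies on Python's late-binding behavior for closures: without the `tn=task_name` default argument, every factory would see the last value of the loop variable. A standalone illustration of the pitfall and the fix (plain Python, no Meta-World required; the task names are just sample strings):

```python
task_names = ["push-v3", "reach-v3", "sweep-v3"]

# Without a default argument, all closures share the loop variable and
# observe its final value when they are eventually called.
bad = [lambda: tn for tn in task_names]

# With `tn=tn`, each lambda snapshots the current value at definition time.
good = [lambda tn=tn: tn for tn in task_names]

print([f() for f in bad])   # ['sweep-v3', 'sweep-v3', 'sweep-v3']
print([f() for f in good])  # ['push-v3', 'reach-v3', 'sweep-v3']
```

This is why each vec-env worker ends up constructing the intended task rather than `n_envs` copies of the last one.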
121
src/lerobot/envs/metaworld_config.json
Normal file
@@ -0,0 +1,121 @@
{
    "TASK_DESCRIPTIONS": {
        "assembly-v3": "Pick up a nut and place it onto a peg",
        "basketball-v3": "Dunk the basketball into the basket",
        "bin-picking-v3": "Grasp the puck from one bin and place it into another bin",
        "box-close-v3": "Grasp the cover and close the box with it",
        "button-press-topdown-v3": "Press a button from the top",
        "button-press-topdown-wall-v3": "Bypass a wall and press a button from the top",
        "button-press-v3": "Press a button",
        "button-press-wall-v3": "Bypass a wall and press a button",
        "coffee-button-v3": "Push a button on the coffee machine",
        "coffee-pull-v3": "Pull a mug from a coffee machine",
        "coffee-push-v3": "Push a mug under a coffee machine",
        "dial-turn-v3": "Rotate a dial 180 degrees",
        "disassemble-v3": "Pick a nut out of a peg",
        "door-close-v3": "Close a door with a revolving joint",
        "door-lock-v3": "Lock the door by rotating the lock clockwise",
        "door-open-v3": "Open a door with a revolving joint",
        "door-unlock-v3": "Unlock the door by rotating the lock counter-clockwise",
        "hand-insert-v3": "Insert the gripper into a hole",
        "drawer-close-v3": "Push and close a drawer",
        "drawer-open-v3": "Open a drawer",
        "faucet-open-v3": "Rotate the faucet counter-clockwise",
        "faucet-close-v3": "Rotate the faucet clockwise",
        "hammer-v3": "Hammer a screw on the wall",
        "handle-press-side-v3": "Press a handle down sideways",
        "handle-press-v3": "Press a handle down",
        "handle-pull-side-v3": "Pull a handle up sideways",
        "handle-pull-v3": "Pull a handle up",
        "lever-pull-v3": "Pull a lever down 90 degrees",
        "peg-insert-side-v3": "Insert a peg sideways",
        "pick-place-wall-v3": "Pick a puck, bypass a wall and place the puck",
        "pick-out-of-hole-v3": "Pick up a puck from a hole",
        "reach-v3": "Reach a goal position",
        "push-back-v3": "Push the puck to a goal",
        "push-v3": "Push the puck to a goal",
        "pick-place-v3": "Pick and place a puck to a goal",
        "plate-slide-v3": "Slide a plate into a cabinet",
        "plate-slide-side-v3": "Slide a plate into a cabinet sideways",
        "plate-slide-back-v3": "Get a plate from the cabinet",
        "plate-slide-back-side-v3": "Get a plate from the cabinet sideways",
        "peg-unplug-side-v3": "Unplug a peg sideways",
        "soccer-v3": "Kick a soccer into the goal",
        "stick-push-v3": "Grasp a stick and push a box using the stick",
        "stick-pull-v3": "Grasp a stick and pull a box with the stick",
        "push-wall-v3": "Bypass a wall and push a puck to a goal",
        "reach-wall-v3": "Bypass a wall and reach a goal",
        "shelf-place-v3": "Pick and place a puck onto a shelf",
        "sweep-into-v3": "Sweep a puck into a hole",
        "sweep-v3": "Sweep a puck off the table",
        "window-open-v3": "Push and open a window",
        "window-close-v3": "Push and close a window"
    },
    "TASK_NAME_TO_ID": {
        "assembly-v3": 0, "basketball-v3": 1, "bin-picking-v3": 2, "box-close-v3": 3,
        "button-press-topdown-v3": 4, "button-press-topdown-wall-v3": 5, "button-press-v3": 6,
        "button-press-wall-v3": 7, "coffee-button-v3": 8, "coffee-pull-v3": 9, "coffee-push-v3": 10,
        "dial-turn-v3": 11, "disassemble-v3": 12, "door-close-v3": 13, "door-lock-v3": 14,
        "door-open-v3": 15, "door-unlock-v3": 16, "drawer-close-v3": 17, "drawer-open-v3": 18,
        "faucet-close-v3": 19, "faucet-open-v3": 20, "hammer-v3": 21, "hand-insert-v3": 22,
        "handle-press-side-v3": 23, "handle-press-v3": 24, "handle-pull-side-v3": 25,
        "handle-pull-v3": 26, "lever-pull-v3": 27, "peg-insert-side-v3": 28, "peg-unplug-side-v3": 29,
        "pick-out-of-hole-v3": 30, "pick-place-v3": 31, "pick-place-wall-v3": 32,
        "plate-slide-back-side-v3": 33, "plate-slide-back-v3": 34, "plate-slide-side-v3": 35,
        "plate-slide-v3": 36, "push-back-v3": 37, "push-v3": 38, "push-wall-v3": 39, "reach-v3": 40,
        "reach-wall-v3": 41, "shelf-place-v3": 42, "soccer-v3": 43, "stick-pull-v3": 44,
        "stick-push-v3": 45, "sweep-into-v3": 46, "sweep-v3": 47, "window-open-v3": 48,
        "window-close-v3": 49
    },
    "DIFFICULTY_TO_TASKS": {
        "easy": [
            "button-press-v3", "button-press-topdown-v3", "button-press-topdown-wall-v3",
            "button-press-wall-v3", "coffee-button-v3", "dial-turn-v3", "door-close-v3",
            "door-lock-v3", "door-open-v3", "door-unlock-v3", "drawer-close-v3", "drawer-open-v3",
            "faucet-close-v3", "faucet-open-v3", "handle-press-v3", "handle-press-side-v3",
            "handle-pull-v3", "handle-pull-side-v3", "lever-pull-v3", "plate-slide-v3",
            "plate-slide-back-v3", "plate-slide-back-side-v3", "plate-slide-side-v3", "reach-v3",
            "reach-wall-v3", "window-close-v3", "window-open-v3", "peg-unplug-side-v3"
        ],
        "medium": [
            "basketball-v3", "bin-picking-v3", "box-close-v3", "coffee-pull-v3", "coffee-push-v3",
            "hammer-v3", "peg-insert-side-v3", "push-wall-v3", "soccer-v3", "sweep-v3", "sweep-into-v3"
        ],
        "hard": [
            "assembly-v3", "hand-insert-v3", "pick-out-of-hole-v3", "pick-place-v3", "push-v3", "push-back-v3"
        ],
        "very_hard": [
            "shelf-place-v3", "disassemble-v3", "stick-pull-v3", "stick-push-v3", "pick-place-wall-v3"
        ]
    },
    "TASK_POLICY_MAPPING": {
        "assembly-v3": "SawyerAssemblyV3Policy", "basketball-v3": "SawyerBasketballV3Policy",
        "bin-picking-v3": "SawyerBinPickingV3Policy", "box-close-v3": "SawyerBoxCloseV3Policy",
        "button-press-topdown-v3": "SawyerButtonPressTopdownV3Policy",
        "button-press-topdown-wall-v3": "SawyerButtonPressTopdownWallV3Policy",
        "button-press-v3": "SawyerButtonPressV3Policy", "button-press-wall-v3": "SawyerButtonPressWallV3Policy",
        "coffee-button-v3": "SawyerCoffeeButtonV3Policy", "coffee-pull-v3": "SawyerCoffeePullV3Policy",
        "coffee-push-v3": "SawyerCoffeePushV3Policy", "dial-turn-v3": "SawyerDialTurnV3Policy",
        "disassemble-v3": "SawyerDisassembleV3Policy", "door-close-v3": "SawyerDoorCloseV3Policy",
        "door-lock-v3": "SawyerDoorLockV3Policy", "door-open-v3": "SawyerDoorOpenV3Policy",
        "door-unlock-v3": "SawyerDoorUnlockV3Policy", "drawer-close-v3": "SawyerDrawerCloseV3Policy",
        "drawer-open-v3": "SawyerDrawerOpenV3Policy", "faucet-close-v3": "SawyerFaucetCloseV3Policy",
        "faucet-open-v3": "SawyerFaucetOpenV3Policy", "hammer-v3": "SawyerHammerV3Policy",
        "hand-insert-v3": "SawyerHandInsertV3Policy", "handle-press-side-v3": "SawyerHandlePressSideV3Policy",
        "handle-press-v3": "SawyerHandlePressV3Policy", "handle-pull-side-v3": "SawyerHandlePullSideV3Policy",
        "handle-pull-v3": "SawyerHandlePullV3Policy", "lever-pull-v3": "SawyerLeverPullV3Policy",
        "peg-insert-side-v3": "SawyerPegInsertionSideV3Policy", "peg-unplug-side-v3": "SawyerPegUnplugSideV3Policy",
        "pick-out-of-hole-v3": "SawyerPickOutOfHoleV3Policy", "pick-place-v3": "SawyerPickPlaceV3Policy",
        "pick-place-wall-v3": "SawyerPickPlaceWallV3Policy",
        "plate-slide-back-side-v3": "SawyerPlateSlideBackSideV3Policy",
        "plate-slide-back-v3": "SawyerPlateSlideBackV3Policy",
        "plate-slide-side-v3": "SawyerPlateSlideSideV3Policy", "plate-slide-v3": "SawyerPlateSlideV3Policy",
        "push-back-v3": "SawyerPushBackV3Policy", "push-v3": "SawyerPushV3Policy",
        "push-wall-v3": "SawyerPushWallV3Policy", "reach-v3": "SawyerReachV3Policy",
        "reach-wall-v3": "SawyerReachWallV3Policy", "shelf-place-v3": "SawyerShelfPlaceV3Policy",
        "soccer-v3": "SawyerSoccerV3Policy", "stick-pull-v3": "SawyerStickPullV3Policy",
        "stick-push-v3": "SawyerStickPushV3Policy", "sweep-into-v3": "SawyerSweepIntoV3Policy",
        "sweep-v3": "SawyerSweepV3Policy", "window-open-v3": "SawyerWindowOpenV3Policy",
        "window-close-v3": "SawyerWindowCloseV3Policy"
    }
}
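Consumers of this config resolve a difficulty group to its task list, falling back to treating an unknown name as a single custom task (the `DIFFICULTY_TO_TASKS.get(group, [group])` rule in `create_metaworld_envs`). A hypothetical illustration using an inline fragment of the same shape (not lerobot code; the fragment is abbreviated):

```python
import json

# Abbreviated config fragment shaped like metaworld_config.json.
config = json.loads("""
{
    "DIFFICULTY_TO_TASKS": {
        "hard": ["assembly-v3", "hand-insert-v3"]
    }
}
""")
difficulty_to_tasks = config["DIFFICULTY_TO_TASKS"]

# Known group -> its preset task list; unknown name -> single custom task.
print(difficulty_to_tasks.get("hard", ["hard"]))        # ['assembly-v3', 'hand-insert-v3']
print(difficulty_to_tasks.get("push-v3", ["push-v3"]))  # ['push-v3']
```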
@@ -180,9 +180,15 @@ def rollout(
         render_callback(env)

-    # VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
-    # available of none of the envs finished.
+    # VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
+    # available if none of the envs finished.
     if "final_info" in info:
-        successes = [info["is_success"] if info is not None else False for info in info["final_info"]]
+        final_info = info["final_info"]
+        if not isinstance(final_info, dict):
+            raise RuntimeError(
+                "Unsupported `final_info` format: expected dict (Gymnasium >= 1.0). "
+                "You're likely using an older version of gymnasium (< 1.0). Please upgrade."
+            )
+        successes = final_info["is_success"].tolist()
     else:
         successes = [False] * env.num_envs
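The hunk above switches from the pre-1.0 Gymnasium format (a list of per-env info dicts) to the 1.0+ format, where `final_info` is a single dict of batched arrays. A sketch of the new extraction path, using a hand-built dict that mimics the assumed batched layout rather than output captured from a real vector env:

```python
import numpy as np

# Hand-built stand-in for the batched `final_info` dict produced by
# Gymnasium >= 1.0 vector envs (assumed layout for illustration).
info = {"final_info": {"is_success": np.array([True, False, True])}}
num_envs = 3

if "final_info" in info:
    final_info = info["final_info"]
    if not isinstance(final_info, dict):
        raise RuntimeError("Unsupported `final_info` format: expected dict (Gymnasium >= 1.0).")
    # One batched array per key, indexed by env, instead of one dict per env.
    successes = final_info["is_success"].tolist()
else:
    successes = [False] * num_envs

print(successes)  # [True, False, True]
```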
@@ -95,7 +95,6 @@ def test_get_policy_and_config_classes(policy_name: str):
 @pytest.mark.parametrize(
     "ds_repo_id,env_name,env_kwargs,policy_name,policy_kwargs",
     [
-        ("lerobot/xarm_lift_medium", "xarm", {}, "tdmpc", {"use_mpc": True}),
         ("lerobot/pusht", "pusht", {}, "diffusion", {}),
         ("lerobot/pusht", "pusht", {}, "vqbet", {}),
         ("lerobot/pusht", "pusht", {}, "act", {}),
@@ -328,8 +327,6 @@ def test_multikey_construction(multikey: bool):
         # TODO(alexander-soare): `policy.use_mpc=false` was previously the default in the config yaml but it
         # was changed to true. For some reason, tests would pass locally, but not in CI. So here we override
         # to test with `policy.use_mpc=false`.
-        ("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": False}, "use_policy"),
-        # ("lerobot/xarm_lift_medium", "tdmpc", {"use_mpc": True}, "use_mpc"),
         # TODO(rcadene): the diffusion model was normalizing the image in mean=0.5 std=0.5 which is a hack supposed to
         # to normalize the image at all. In our current codebase we dont normalize at all. But there is still a minor difference
         # that fails the test. However, by testing to normalize the image with 0.5 0.5 in the current codebase, the test pass.