forked from tangger/lerobot
Simplify configs (#550)
Co-authored-by: Remi <remi.cadene@huggingface.co> Co-authored-by: HUANG TZU-CHUN <137322177+tc-huang@users.noreply.github.com>
This commit is contained in:
1
lerobot/common/envs/__init__.py
Normal file
1
lerobot/common/envs/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv # noqa: F401
|
||||
142
lerobot/common/envs/configs.py
Normal file
142
lerobot/common/envs/configs.py
Normal file
@@ -0,0 +1,142 @@
|
||||
import abc
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.common.constants import ACTION, OBS_ENV, OBS_IMAGE, OBS_IMAGES, OBS_ROBOT
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
task: str | None = None
|
||||
fps: int = 30
|
||||
features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
features_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
@abc.abstractproperty
|
||||
def gym_kwargs(self) -> dict:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("aloha")
|
||||
@dataclass
|
||||
class AlohaEnv(EnvConfig):
|
||||
task: str = "AlohaInsertion-v0"
|
||||
fps: int = 50
|
||||
episode_length: int = 400
|
||||
obs_type: str = "pixels_agent_pos"
|
||||
render_mode: str = "rgb_array"
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(14,)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
"agent_pos": OBS_ROBOT,
|
||||
"top": f"{OBS_IMAGE}.top",
|
||||
"pixels/top": f"{OBS_IMAGES}.top",
|
||||
}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.obs_type == "pixels":
|
||||
self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
|
||||
elif self.obs_type == "pixels_agent_pos":
|
||||
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(14,))
|
||||
self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3))
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {
|
||||
"obs_type": self.obs_type,
|
||||
"render_mode": self.render_mode,
|
||||
"max_episode_steps": self.episode_length,
|
||||
}
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("pusht")
|
||||
@dataclass
|
||||
class PushtEnv(EnvConfig):
|
||||
task: str = "PushT-v0"
|
||||
fps: int = 10
|
||||
episode_length: int = 300
|
||||
obs_type: str = "pixels_agent_pos"
|
||||
render_mode: str = "rgb_array"
|
||||
visualization_width: int = 384
|
||||
visualization_height: int = 384
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)),
|
||||
"agent_pos": PolicyFeature(type=FeatureType.STATE, shape=(2,)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
"agent_pos": OBS_ROBOT,
|
||||
"environment_state": OBS_ENV,
|
||||
"pixels": OBS_IMAGE,
|
||||
}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.obs_type == "pixels_agent_pos":
|
||||
self.features["pixels"] = PolicyFeature(type=FeatureType.VISUAL, shape=(384, 384, 3))
|
||||
elif self.obs_type == "environment_state_agent_pos":
|
||||
self.features["environment_state"] = PolicyFeature(type=FeatureType.ENV, shape=(16,))
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {
|
||||
"obs_type": self.obs_type,
|
||||
"render_mode": self.render_mode,
|
||||
"visualization_width": self.visualization_width,
|
||||
"visualization_height": self.visualization_height,
|
||||
"max_episode_steps": self.episode_length,
|
||||
}
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("xarm")
|
||||
@dataclass
|
||||
class XarmEnv(EnvConfig):
|
||||
task: str = "XarmLift-v0"
|
||||
fps: int = 15
|
||||
episode_length: int = 200
|
||||
obs_type: str = "pixels_agent_pos"
|
||||
render_mode: str = "rgb_array"
|
||||
visualization_width: int = 384
|
||||
visualization_height: int = 384
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
|
||||
"pixels": PolicyFeature(type=FeatureType.VISUAL, shape=(84, 84, 3)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
"agent_pos": OBS_ROBOT,
|
||||
"pixels": OBS_IMAGE,
|
||||
}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.obs_type == "pixels_agent_pos":
|
||||
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,))
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {
|
||||
"obs_type": self.obs_type,
|
||||
"render_mode": self.render_mode,
|
||||
"visualization_width": self.visualization_width,
|
||||
"visualization_height": self.visualization_height,
|
||||
"max_episode_steps": self.episode_length,
|
||||
}
|
||||
@@ -16,43 +16,54 @@
|
||||
import importlib
|
||||
|
||||
import gymnasium as gym
|
||||
from omegaconf import DictConfig
|
||||
|
||||
from lerobot.common.envs.configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv
|
||||
|
||||
|
||||
def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None:
|
||||
"""Makes a gym vector environment according to the evaluation config.
|
||||
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||
if env_type == "aloha":
|
||||
return AlohaEnv(**kwargs)
|
||||
elif env_type == "pusht":
|
||||
return PushtEnv(**kwargs)
|
||||
elif env_type == "xarm":
|
||||
return XarmEnv(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{env_type}' is not available.")
|
||||
|
||||
n_envs can be used to override eval.batch_size in the configuration. Must be at least 1.
|
||||
|
||||
def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> gym.vector.VectorEnv | None:
|
||||
"""Makes a gym vector environment according to the config.
|
||||
|
||||
Args:
|
||||
cfg (EnvConfig): the config of the environment to instantiate.
|
||||
n_envs (int, optional): The number of parallelized env to return. Defaults to 1.
|
||||
use_async_envs (bool, optional): Wether to return an AsyncVectorEnv or a SyncVectorEnv. Defaults to
|
||||
False.
|
||||
|
||||
Raises:
|
||||
ValueError: if n_envs < 1
|
||||
ModuleNotFoundError: If the requested env package is not intalled
|
||||
|
||||
Returns:
|
||||
gym.vector.VectorEnv: The parallelized gym.env instance.
|
||||
"""
|
||||
if n_envs is not None and n_envs < 1:
|
||||
if n_envs < 1:
|
||||
raise ValueError("`n_envs must be at least 1")
|
||||
|
||||
if cfg.env.name == "real_world":
|
||||
return
|
||||
|
||||
package_name = f"gym_{cfg.env.name}"
|
||||
package_name = f"gym_{cfg.type}"
|
||||
|
||||
try:
|
||||
importlib.import_module(package_name)
|
||||
except ModuleNotFoundError as e:
|
||||
print(
|
||||
f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.env.name}]'`"
|
||||
)
|
||||
print(f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.type}]'`")
|
||||
raise e
|
||||
|
||||
gym_handle = f"{package_name}/{cfg.env.task}"
|
||||
gym_kwgs = dict(cfg.env.get("gym", {}))
|
||||
|
||||
if cfg.env.get("episode_length"):
|
||||
gym_kwgs["max_episode_steps"] = cfg.env.episode_length
|
||||
gym_handle = f"{package_name}/{cfg.task}"
|
||||
|
||||
# batched version of the env that returns an observation of shape (b, c)
|
||||
env_cls = gym.vector.AsyncVectorEnv if cfg.eval.use_async_envs else gym.vector.SyncVectorEnv
|
||||
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
|
||||
env = env_cls(
|
||||
[
|
||||
lambda: gym.make(gym_handle, disable_env_checker=True, **gym_kwgs)
|
||||
for _ in range(n_envs if n_envs is not None else cfg.eval.batch_size)
|
||||
]
|
||||
[lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)]
|
||||
)
|
||||
|
||||
return env
|
||||
|
||||
@@ -18,8 +18,13 @@ import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.common.envs.configs import EnvConfig
|
||||
from lerobot.common.utils.utils import get_channel_first_image_shape
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
|
||||
|
||||
def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
|
||||
# TODO(aliberts, rcadene): refactor this to use features from the environment (no hardcoding)
|
||||
"""Convert environment observation to LeRobot format observation.
|
||||
Args:
|
||||
observation: Dictionary of observation batches from a Gym vector environment.
|
||||
@@ -35,6 +40,7 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
imgs = {"observation.image": observations["pixels"]}
|
||||
|
||||
for imgkey, img in imgs.items():
|
||||
# TODO(aliberts, rcadene): use transforms.ToTensor()?
|
||||
img = torch.from_numpy(img)
|
||||
|
||||
# sanity check that images are channel last
|
||||
@@ -60,3 +66,23 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
# requirement for "agent_pos"
|
||||
return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
|
||||
return return_observations
|
||||
|
||||
|
||||
def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
|
||||
# TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is
|
||||
# (need to also refactor preprocess_observation and externalize normalization from policies)
|
||||
policy_features = {}
|
||||
for key, ft in env_cfg.features.items():
|
||||
if ft.type is FeatureType.VISUAL:
|
||||
if len(ft.shape) != 3:
|
||||
raise ValueError(f"Number of dimensions of {key} != 3 (shape={ft.shape})")
|
||||
|
||||
shape = get_channel_first_image_shape(ft.shape)
|
||||
feature = PolicyFeature(type=ft.type, shape=shape)
|
||||
else:
|
||||
feature = ft
|
||||
|
||||
policy_key = env_cfg.features_map[key]
|
||||
policy_features[policy_key] = feature
|
||||
|
||||
return policy_features
|
||||
|
||||
Reference in New Issue
Block a user