import abc from dataclasses import dataclass, field import draccus from lerobot.common.constants import ACTION, OBS_ENV, OBS_IMAGE, OBS_IMAGES, OBS_ROBOT from lerobot.configs.types import FeatureType, PolicyFeature @dataclass class EnvConfig(draccus.ChoiceRegistry, abc.ABC): task: str | None = None fps: int = 30 features: dict[str, PolicyFeature] = field(default_factory=dict) features_map: dict[str, str] = field(default_factory=dict) @property def type(self) -> str: return self.get_choice_name(self.__class__) @abc.abstractproperty def gym_kwargs(self) -> dict: raise NotImplementedError() @EnvConfig.register_subclass("aloha") @dataclass class AlohaEnv(EnvConfig): task: str = "AlohaInsertion-v0" fps: int = 50 episode_length: int = 400 obs_type: str = "pixels_agent_pos" render_mode: str = "rgb_array" features: dict[str, PolicyFeature] = field( default_factory=lambda: { "action": PolicyFeature(type=FeatureType.ACTION, shape=(14,)), } ) features_map: dict[str, str] = field( default_factory=lambda: { "action": ACTION, "agent_pos": OBS_ROBOT, "top": f"{OBS_IMAGE}.top", "pixels/top": f"{OBS_IMAGES}.top", } ) def __post_init__(self): if self.obs_type == "pixels": self.features["top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3)) elif self.obs_type == "pixels_agent_pos": self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(14,)) self.features["pixels/top"] = PolicyFeature(type=FeatureType.VISUAL, shape=(480, 640, 3)) @property def gym_kwargs(self) -> dict: return { "obs_type": self.obs_type, "render_mode": self.render_mode, "max_episode_steps": self.episode_length, } @EnvConfig.register_subclass("pusht") @dataclass class PushtEnv(EnvConfig): task: str = "PushT-v0" fps: int = 10 episode_length: int = 300 obs_type: str = "pixels_agent_pos" render_mode: str = "rgb_array" visualization_width: int = 384 visualization_height: int = 384 features: dict[str, PolicyFeature] = field( default_factory=lambda: { "action": PolicyFeature(type=FeatureType.ACTION, shape=(2,)), "agent_pos": PolicyFeature(type=FeatureType.STATE, shape=(2,)), } ) features_map: dict[str, str] = field( default_factory=lambda: { "action": ACTION, "agent_pos": OBS_ROBOT, "environment_state": OBS_ENV, "pixels": OBS_IMAGE, } ) def __post_init__(self): if self.obs_type == "pixels_agent_pos": self.features["pixels"] = PolicyFeature(type=FeatureType.VISUAL, shape=(384, 384, 3)) elif self.obs_type == "environment_state_agent_pos": self.features["environment_state"] = PolicyFeature(type=FeatureType.ENV, shape=(16,)) @property def gym_kwargs(self) -> dict: return { "obs_type": self.obs_type, "render_mode": self.render_mode, "visualization_width": self.visualization_width, "visualization_height": self.visualization_height, "max_episode_steps": self.episode_length, } @EnvConfig.register_subclass("xarm") @dataclass class XarmEnv(EnvConfig): task: str = "XarmLift-v0" fps: int = 15 episode_length: int = 200 obs_type: str = "pixels_agent_pos" render_mode: str = "rgb_array" visualization_width: int = 384 visualization_height: int = 384 features: dict[str, PolicyFeature] = field( default_factory=lambda: { "action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)), "pixels": PolicyFeature(type=FeatureType.VISUAL, shape=(84, 84, 3)), } ) features_map: dict[str, str] = field( default_factory=lambda: { "action": ACTION, "agent_pos": OBS_ROBOT, "pixels": OBS_IMAGE, } ) def __post_init__(self): if self.obs_type == "pixels_agent_pos": self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(4,)) @property def gym_kwargs(self) -> dict: return { "obs_type": self.obs_type, "render_mode": self.render_mode, "visualization_width": self.visualization_width, "visualization_height": self.visualization_height, "max_episode_steps": self.episode_length, }