diff --git a/lerobot/common/envs/configs.py b/lerobot/common/envs/configs.py
index 8b06cc4c..838f48bc 100644
--- a/lerobot/common/envs/configs.py
+++ b/lerobot/common/envs/configs.py
@@ -277,3 +277,56 @@ class ManiskillEnvConfig(EnvConfig):
             "sensor_configs": {"width": self.image_size, "height": self.image_size},
             "num_envs": 1,
         }
+
+
+@EnvConfig.register_subclass("hil")
+@dataclass
+class HILEnvConfig(EnvConfig):
+    """Configuration for the HIL environment."""
+
+    type: str = "hil"
+    name: str = "PandaPickCube"
+    task: str = "PandaPickCubeKeyboard-v0"
+    use_viewer: bool = True
+    gripper_penalty: float = 0.0
+    use_gamepad: bool = True
+    state_dim: int = 18
+    action_dim: int = 4
+    fps: int = 100
+    episode_length: int = 100
+    video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
+    features: dict[str, PolicyFeature] = field(
+        default_factory=lambda: {
+            "action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
+            "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
+            "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(18,)),
+        }
+    )
+    features_map: dict[str, str] = field(
+        default_factory=lambda: {
+            "action": ACTION,
+            "observation.image": OBS_IMAGE,
+            "observation.state": OBS_ROBOT,
+        }
+    )
+    ################# args from HILSerlRobotEnv
+    reward_classifier_pretrained_path: Optional[str] = None
+    robot: Optional[RobotConfig] = None
+    wrapper: Optional[EnvWrapperConfig] = None
+    mode: Optional[str] = None  # Either "record", "replay", or None
+    repo_id: Optional[str] = None
+    dataset_root: Optional[str] = None
+    num_episodes: int = 10  # only used in "record" mode
+    episode: int = 0
+    device: str = "cuda"
+    push_to_hub: bool = True
+    pretrained_policy_name_or_path: Optional[str] = None
+    ############################
+
+    @property
+    def gym_kwargs(self) -> dict:
+        return {
+            "use_viewer": self.use_viewer,
+            "use_gamepad": self.use_gamepad,
+            "gripper_penalty": self.gripper_penalty,
+        }
diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py
index 8450f84b..e6f91ce8 100644
--- a/lerobot/common/envs/factory.py
+++ b/lerobot/common/envs/factory.py
@@ -17,7 +17,7 @@ import importlib
 
 import gymnasium as gym
 
-from lerobot.common.envs.configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv
+from lerobot.common.envs.configs import AlohaEnv, EnvConfig, HILEnvConfig, PushtEnv, XarmEnv
 
 
 def make_env_config(env_type: str, **kwargs) -> EnvConfig:
@@ -27,6 +27,8 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
         return PushtEnv(**kwargs)
     elif env_type == "xarm":
         return XarmEnv(**kwargs)
+    elif env_type == "hil":
+        return HILEnvConfig(**kwargs)
     else:
         raise ValueError(f"Policy type '{env_type}' is not available.")
 
@@ -65,5 +67,7 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
     env = env_cls(
         [lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)]
    )
+    # TODO: add observation processor wrapper and remove preprocess_observation in the codebase
+    # env = ObservationProcessorWrapper(env=env)
 
     return env
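For context, a minimal usage sketch of the `hil` registration above (not part of the patch); it assumes this branch is installed, and the override values are illustrative:

```python
# Hypothetical usage sketch of the new "hil" env type (not part of this patch).
from lerobot.common.envs.factory import make_env_config

# Any HILEnvConfig field can be overridden through kwargs; these values are illustrative.
cfg = make_env_config("hil", task="PandaPickCubeKeyboard-v0", device="cpu")
assert cfg.type == "hil"

# gym_kwargs gathers the flags that make_env forwards to gym.make for this env type.
print(cfg.gym_kwargs)  # {'use_viewer': True, 'use_gamepad': True, 'gripper_penalty': 0.0}
```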
diff --git a/lerobot/common/envs/utils.py b/lerobot/common/envs/utils.py
index 83334f87..78517abe 100644
--- a/lerobot/common/envs/utils.py
+++ b/lerobot/common/envs/utils.py
@@ -47,6 +47,8 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
             # TODO(aliberts, rcadene): use transforms.ToTensor()?
             img = torch.from_numpy(img)
+            if img.dim() == 3:
+                img = img.unsqueeze(0)
 
             # sanity check that images are channel last
             _, h, w, c = img.shape
             assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
@@ -62,16 +64,50 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
         return_observations[imgkey] = img
 
     if "environment_state" in observations:
-        return_observations["observation.environment_state"] = torch.from_numpy(
-            observations["environment_state"]
-        ).float()
+        env_state = torch.from_numpy(observations["environment_state"]).float()
+        if env_state.dim() == 1:
+            env_state = env_state.unsqueeze(0)
+
+        return_observations["observation.environment_state"] = env_state
 
     # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
-    # requirement for "agent_pos"
-    return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
+    agent_pos = torch.from_numpy(observations["agent_pos"]).float()
+    if agent_pos.dim() == 1:
+        agent_pos = agent_pos.unsqueeze(0)
+    return_observations["observation.state"] = agent_pos
+
     return return_observations
 
 
+class ObservationProcessorWrapper(gym.vector.VectorEnvWrapper):
+    def __init__(self, env: gym.vector.VectorEnv):
+        super().__init__(env)
+
+    def _observations(self, observations: dict[str, Any]) -> dict[str, Any]:
+        return preprocess_observation(observations)
+
+    def reset(
+        self,
+        *,
+        seed: int | list[int] | None = None,
+        options: dict[str, Any] | None = None,
+    ):
+        """Modifies the observations returned from the environment ``reset`` using :meth:`_observations`."""
+        observations, infos = self.env.reset(seed=seed, options=options)
+        return self._observations(observations), infos
+
+    def step(self, actions):
+        """Modifies the observations returned from the environment ``step`` using :meth:`_observations`."""
+        observations, rewards, terminations, truncations, infos = self.env.step(actions)
+        return (
+            self._observations(observations),
+            rewards,
+            terminations,
+            truncations,
+            infos,
+        )
+
+
 def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
     # TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is
     # (need to also refactor preprocess_observation and externalize normalization from policies)
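A quick sketch of the batch-dimension handling these `preprocess_observation` changes add, assuming an unbatched observation from a single (non-vectorized) environment; the shapes are illustrative:

```python
# Sketch (not part of the patch): the new unsqueeze(0) calls let unbatched
# observations from a single env pass through preprocess_observation.
import numpy as np

from lerobot.common.envs.utils import preprocess_observation

obs = {
    "pixels": np.zeros((128, 128, 3), dtype=np.uint8),  # channel-last uint8 image, no batch dim
    "agent_pos": np.zeros((18,), dtype=np.float32),  # flat state vector, no batch dim
}
out = preprocess_observation(obs)

print(out["observation.image"].shape)  # torch.Size([1, 3, 128, 128]), float32 in [0, 1]
print(out["observation.state"].shape)  # torch.Size([1, 18])
```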
diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py
index 48ba91e3..02659e69 100644
--- a/lerobot/scripts/server/gym_manipulator.py
+++ b/lerobot/scripts/server/gym_manipulator.py
@@ -11,6 +11,7 @@ import torch
 import torchvision.transforms.functional as F  # noqa: N812
 
 from lerobot.common.envs.configs import EnvConfig
+from lerobot.common.envs.utils import preprocess_observation
 from lerobot.common.robot_devices.control_utils import (
     busy_wait,
     is_headless,
@@ -1713,6 +1714,50 @@ class GamepadControlWrapper(gym.Wrapper):
         return self.env.close()
 
 
+class GymHilDeviceWrapper(gym.Wrapper):
+    def __init__(self, env, device="cpu"):
+        super().__init__(env)
+        self.device = device
+
+    def step(self, action):
+        obs, reward, terminated, truncated, info = self.env.step(action)
+        for k in obs:
+            obs[k] = obs[k].to(self.device)
+        if "action_intervention" in info:
+            info["action_intervention"] = torch.from_numpy(info["action_intervention"]).to(self.device)
+        return obs, reward, terminated, truncated, info
+
+    def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None):
+        obs, info = self.env.reset(seed=seed, options=options)
+        for k in obs:
+            obs[k] = obs[k].to(self.device)
+        if "action_intervention" in info:
+            info["action_intervention"] = torch.from_numpy(info["action_intervention"]).to(self.device)
+        return obs, info
+
+
+class GymHilObservationProcessorWrapper(gym.ObservationWrapper):
+    def __init__(self, env: gym.Env):
+        super().__init__(env)
+        prev_space = self.observation_space
+        new_space = {}
+
+        for key in prev_space:
+            if "pixels" in key:
+                for k in prev_space["pixels"]:
+                    new_space[f"observation.images.{k}"] = gym.spaces.Box(
+                        0.0, 255.0, shape=(3, 128, 128), dtype=np.uint8
+                    )
+
+            if key == "agent_pos":
+                new_space["observation.state"] = prev_space["agent_pos"]
+
+        self.observation_space = gym.spaces.Dict(new_space)
+
+    def observation(self, observation: dict[str, Any]) -> dict[str, Any]:
+        return preprocess_observation(observation)
+
+
 ###########################################################
 # Factory functions
 ###########################################################
@@ -1729,8 +1774,27 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
         cfg: Configuration object containing environment parameters.
 
     Returns:
+        A vectorized gym environment with all necessary wrappers applied.
     """
+    if cfg.type == "hil":
+        import gymnasium as gym
+
+        # TODO (azouitine)
+        env = gym.make(
+            f"gym_hil/{cfg.task}",
+            image_obs=True,
+            render_mode="human",
+            step_size=cfg.wrapper.ee_action_space_params.x_step_size,
+            use_gripper=cfg.wrapper.use_gripper,
+            gripper_penalty=cfg.wrapper.gripper_penalty,
+        )
+        env = GymHilObservationProcessorWrapper(env=env)
+        env = GymHilDeviceWrapper(env=env, device=cfg.device)
+        env = BatchCompatibleWrapper(env=env)
+        env = TorchActionWrapper(env=env, device=cfg.device)
+        return env
+
     robot = make_robot_from_config(cfg.robot)
 
     # Create base environment
     env = RobotEnv(
@@ -1883,6 +1947,11 @@ def record_dataset(env, policy, cfg, success_collection_steps=15):
         },
         "next.reward": {"dtype": "float32", "shape": (1,), "names": None},
         "next.done": {"dtype": "bool", "shape": (1,), "names": None},
+        "complementary_info.discrete_penalty": {
+            "dtype": "float32",
+            "shape": (1,),
+            "names": ["discrete_penalty"],
+        },
     }
 
     # Add image features
@@ -1962,6 +2031,9 @@ def record_dataset(env, policy, cfg, success_collection_steps=15):
         frame["next.done"] = np.array([really_done], dtype=bool)
         frame["task"] = cfg.task
+        frame["complementary_info.discrete_penalty"] = torch.tensor(
+            [info.get("discrete_penalty", 0.0)], dtype=torch.float32
+        )
 
         dataset.add_frame(frame)
 
         # Maintain consistent timing
@@ -2074,7 +2146,7 @@ def main(cfg: EnvConfig):
     num_episode = 0
     successes = []
-    while num_episode < 20:
+    while num_episode < 10:
         start_loop_s = time.perf_counter()
         # Sample a new random action from the robot's action space.
         new_random_action = env.action_space.sample()
diff --git a/pyproject.toml b/pyproject.toml
index d854d73f..b3c4b7ed 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -88,7 +88,7 @@ dora = [
 ]
 dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
 feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
-hilserl = ["transformers>=4.48.0", "torchmetrics>=1.6.0"]
+hilserl = ["transformers>=4.48.0", "torchmetrics>=1.6.0", "gym-hil>=0.1.2"]
 intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
 pi0 = ["transformers>=4.48.0"]
 pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"]
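Finally, a hypothetical end-to-end sketch of the new `cfg.type == "hil"` branch of `make_robot_env`. It assumes `gym-hil>=0.1.2` is installed (now pulled in by the `hilserl` extra) and that `EnvWrapperConfig` is importable from `lerobot.common.envs.configs` next to `HILEnvConfig`, with defaults that populate `ee_action_space_params`; both assumptions should be checked against this branch:

```python
# Sketch (not part of the patch): building and stepping the gym_hil-backed env.
import torch

from lerobot.common.envs.configs import EnvWrapperConfig, HILEnvConfig  # import path assumed
from lerobot.scripts.server.gym_manipulator import make_robot_env

cfg = HILEnvConfig(device="cpu", wrapper=EnvWrapperConfig())
env = make_robot_env(cfg)  # gym_hil env + observation/device/batch/torch-action wrappers

obs, info = env.reset()
# Observations come back preprocessed, batched, and already on cfg.device.
action = torch.zeros(cfg.action_dim)  # TorchActionWrapper is assumed to accept torch tensors
obs, reward, terminated, truncated, info = env.step(action)
env.close()
```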