[HIL SERL] Env management and add gym-hil (#1077)

Co-authored-by: Michel Aractingi <michel.aractingi@gmail.com>
This commit is contained in:
Adil Zouitine
2025-05-07 09:39:21 +02:00
committed by AdilZouitine
parent e76f29ff7a
commit 049773a5fa
5 changed files with 173 additions and 8 deletions

View File

@@ -277,3 +277,56 @@ class ManiskillEnvConfig(EnvConfig):
"sensor_configs": {"width": self.image_size, "height": self.image_size},
"num_envs": 1,
}
@EnvConfig.register_subclass("hil")
@dataclass
class HILEnvConfig(EnvConfig):
"""Configuration for the HIL environment."""
type: str = "hil"
name: str = "PandaPickCube"
task: str = "PandaPickCubeKeyboard-v0"
use_viewer: bool = True
gripper_penalty: float = 0.0
use_gamepad: bool = True
state_dim: int = 18
action_dim: int = 4
fps: int = 100
episode_length: int = 100
video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
features: dict[str, PolicyFeature] = field(
default_factory=lambda: {
"action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(18,)),
}
)
features_map: dict[str, str] = field(
default_factory=lambda: {
"action": ACTION,
"observation.image": OBS_IMAGE,
"observation.state": OBS_ROBOT,
}
)
################# args from hilserlrobotenv
reward_classifier_pretrained_path: Optional[str] = None
robot: Optional[RobotConfig] = None
wrapper: Optional[EnvWrapperConfig] = None
mode: str = None # Either "record", "replay", None
repo_id: Optional[str] = None
dataset_root: Optional[str] = None
num_episodes: int = 10 # only for record mode
episode: int = 0
device: str = "cuda"
push_to_hub: bool = True
pretrained_policy_name_or_path: Optional[str] = None
############################
@property
def gym_kwargs(self) -> dict:
return {
"use_viewer": self.use_viewer,
"use_gamepad": self.use_gamepad,
"gripper_penalty": self.gripper_penalty,
}

View File

@@ -17,7 +17,7 @@ import importlib
import gymnasium as gym
from lerobot.common.envs.configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv
from lerobot.common.envs.configs import AlohaEnv, EnvConfig, HILEnvConfig, PushtEnv, XarmEnv
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
@@ -27,6 +27,8 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
return PushtEnv(**kwargs)
elif env_type == "xarm":
return XarmEnv(**kwargs)
elif env_type == "hil":
return HILEnvConfig(**kwargs)
else:
raise ValueError(f"Policy type '{env_type}' is not available.")
@@ -65,5 +67,7 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
env = env_cls(
[lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)]
)
# TODO: add observation processor wrapper and remove preprocess_observation in the codebase
# env = ObservationProcessorWrapper(env=env)
return env

View File

@@ -47,6 +47,8 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
# TODO(aliberts, rcadene): use transforms.ToTensor()?
img = torch.from_numpy(img)
if img.dim() == 3:
img = img.unsqueeze(0)
# sanity check that images are channel last
_, h, w, c = img.shape
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
@@ -62,16 +64,50 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
return_observations[imgkey] = img
if "environment_state" in observations:
return_observations["observation.environment_state"] = torch.from_numpy(
observations["environment_state"]
).float()
env_state = torch.from_numpy(observations["environment_state"]).float()
if env_state.dim() == 1:
env_state = env_state.unsqueeze(0)
return_observations["observation.environment_state"] = env_state
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
# requirement for "agent_pos"
return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
agent_pos = torch.from_numpy(observations["agent_pos"]).float()
if agent_pos.dim() == 1:
agent_pos = agent_pos.unsqueeze(0)
return_observations["observation.state"] = agent_pos
return return_observations
class ObservationProcessorWrapper(gym.vector.VectorEnvWrapper):
def __init__(self, env: gym.vector.VectorEnv):
super().__init__(env)
def _observations(self, observations: dict[str, Any]) -> dict[str, Any]:
return preprocess_observation(observations)
def reset(
self,
*,
seed: int | list[int] | None = None,
options: dict[str, Any] | None = None,
):
"""Modifies the observation returned from the environment ``reset`` using the :meth:`observation`."""
observations, infos = self.env.reset(seed=seed, options=options)
return self._observations(observations), infos
def step(self, actions):
"""Modifies the observation returned from the environment ``step`` using the :meth:`observation`."""
observations, rewards, terminations, truncations, infos = self.env.step(actions)
return (
self._observations(observations),
rewards,
terminations,
truncations,
infos,
)
def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
# TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is
# (need to also refactor preprocess_observation and externalize normalization from policies)