[HIL SERL] Env management and add gym-hil (#1077)
Co-authored-by: Michel Aractingi <michel.aractingi@gmail.com>
This commit is contained in:
@@ -276,3 +276,56 @@ class ManiskillEnvConfig(EnvConfig):
|
|||||||
"sensor_configs": {"width": self.image_size, "height": self.image_size},
|
"sensor_configs": {"width": self.image_size, "height": self.image_size},
|
||||||
"num_envs": 1,
|
"num_envs": 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@EnvConfig.register_subclass("hil")
|
||||||
|
@dataclass
|
||||||
|
class HILEnvConfig(EnvConfig):
|
||||||
|
"""Configuration for the HIL environment."""
|
||||||
|
|
||||||
|
type: str = "hil"
|
||||||
|
name: str = "PandaPickCube"
|
||||||
|
task: str = "PandaPickCubeKeyboard-v0"
|
||||||
|
use_viewer: bool = True
|
||||||
|
gripper_penalty: float = 0.0
|
||||||
|
use_gamepad: bool = True
|
||||||
|
state_dim: int = 18
|
||||||
|
action_dim: int = 4
|
||||||
|
fps: int = 100
|
||||||
|
episode_length: int = 100
|
||||||
|
video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
|
||||||
|
features: dict[str, PolicyFeature] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)),
|
||||||
|
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)),
|
||||||
|
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(18,)),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
features_map: dict[str, str] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"action": ACTION,
|
||||||
|
"observation.image": OBS_IMAGE,
|
||||||
|
"observation.state": OBS_ROBOT,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
################# args from hilserlrobotenv
|
||||||
|
reward_classifier_pretrained_path: Optional[str] = None
|
||||||
|
robot: Optional[RobotConfig] = None
|
||||||
|
wrapper: Optional[EnvWrapperConfig] = None
|
||||||
|
mode: str = None # Either "record", "replay", None
|
||||||
|
repo_id: Optional[str] = None
|
||||||
|
dataset_root: Optional[str] = None
|
||||||
|
num_episodes: int = 10 # only for record mode
|
||||||
|
episode: int = 0
|
||||||
|
device: str = "cuda"
|
||||||
|
push_to_hub: bool = True
|
||||||
|
pretrained_policy_name_or_path: Optional[str] = None
|
||||||
|
############################
|
||||||
|
|
||||||
|
@property
|
||||||
|
def gym_kwargs(self) -> dict:
|
||||||
|
return {
|
||||||
|
"use_viewer": self.use_viewer,
|
||||||
|
"use_gamepad": self.use_gamepad,
|
||||||
|
"gripper_penalty": self.gripper_penalty,
|
||||||
|
}
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ import importlib
|
|||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
|
|
||||||
from lerobot.common.envs.configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv
|
from lerobot.common.envs.configs import AlohaEnv, EnvConfig, HILEnvConfig, PushtEnv, XarmEnv
|
||||||
|
|
||||||
|
|
||||||
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||||
@@ -27,6 +27,8 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
|||||||
return PushtEnv(**kwargs)
|
return PushtEnv(**kwargs)
|
||||||
elif env_type == "xarm":
|
elif env_type == "xarm":
|
||||||
return XarmEnv(**kwargs)
|
return XarmEnv(**kwargs)
|
||||||
|
elif env_type == "hil":
|
||||||
|
return HILEnvConfig(**kwargs)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Policy type '{env_type}' is not available.")
|
raise ValueError(f"Policy type '{env_type}' is not available.")
|
||||||
|
|
||||||
@@ -65,5 +67,7 @@ def make_env(cfg: EnvConfig, n_envs: int = 1, use_async_envs: bool = False) -> g
|
|||||||
env = env_cls(
|
env = env_cls(
|
||||||
[lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)]
|
[lambda: gym.make(gym_handle, disable_env_checker=True, **cfg.gym_kwargs) for _ in range(n_envs)]
|
||||||
)
|
)
|
||||||
|
# TODO: add observation processor wrapper and remove preprocess_observation in the codebase
|
||||||
|
# env = ObservationProcessorWrapper(env=env)
|
||||||
|
|
||||||
return env
|
return env
|
||||||
|
|||||||
@@ -47,6 +47,8 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
|||||||
# TODO(aliberts, rcadene): use transforms.ToTensor()?
|
# TODO(aliberts, rcadene): use transforms.ToTensor()?
|
||||||
img = torch.from_numpy(img)
|
img = torch.from_numpy(img)
|
||||||
|
|
||||||
|
if img.dim() == 3:
|
||||||
|
img = img.unsqueeze(0)
|
||||||
# sanity check that images are channel last
|
# sanity check that images are channel last
|
||||||
_, h, w, c = img.shape
|
_, h, w, c = img.shape
|
||||||
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
|
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
|
||||||
@@ -62,16 +64,50 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
|||||||
return_observations[imgkey] = img
|
return_observations[imgkey] = img
|
||||||
|
|
||||||
if "environment_state" in observations:
|
if "environment_state" in observations:
|
||||||
return_observations["observation.environment_state"] = torch.from_numpy(
|
env_state = torch.from_numpy(observations["environment_state"]).float()
|
||||||
observations["environment_state"]
|
if env_state.dim() == 1:
|
||||||
).float()
|
env_state = env_state.unsqueeze(0)
|
||||||
|
|
||||||
|
return_observations["observation.environment_state"] = env_state
|
||||||
|
|
||||||
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
|
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
|
||||||
# requirement for "agent_pos"
|
agent_pos = torch.from_numpy(observations["agent_pos"]).float()
|
||||||
return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
|
if agent_pos.dim() == 1:
|
||||||
|
agent_pos = agent_pos.unsqueeze(0)
|
||||||
|
return_observations["observation.state"] = agent_pos
|
||||||
|
|
||||||
return return_observations
|
return return_observations
|
||||||
|
|
||||||
|
|
||||||
|
class ObservationProcessorWrapper(gym.vector.VectorEnvWrapper):
|
||||||
|
def __init__(self, env: gym.vector.VectorEnv):
|
||||||
|
super().__init__(env)
|
||||||
|
|
||||||
|
def _observations(self, observations: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
return preprocess_observation(observations)
|
||||||
|
|
||||||
|
def reset(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
seed: int | list[int] | None = None,
|
||||||
|
options: dict[str, Any] | None = None,
|
||||||
|
):
|
||||||
|
"""Modifies the observation returned from the environment ``reset`` using the :meth:`observation`."""
|
||||||
|
observations, infos = self.env.reset(seed=seed, options=options)
|
||||||
|
return self._observations(observations), infos
|
||||||
|
|
||||||
|
def step(self, actions):
|
||||||
|
"""Modifies the observation returned from the environment ``step`` using the :meth:`observation`."""
|
||||||
|
observations, rewards, terminations, truncations, infos = self.env.step(actions)
|
||||||
|
return (
|
||||||
|
self._observations(observations),
|
||||||
|
rewards,
|
||||||
|
terminations,
|
||||||
|
truncations,
|
||||||
|
infos,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
|
def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
|
||||||
# TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is
|
# TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is
|
||||||
# (need to also refactor preprocess_observation and externalize normalization from policies)
|
# (need to also refactor preprocess_observation and externalize normalization from policies)
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import torch
|
|||||||
import torchvision.transforms.functional as F # noqa: N812
|
import torchvision.transforms.functional as F # noqa: N812
|
||||||
|
|
||||||
from lerobot.common.envs.configs import EnvConfig
|
from lerobot.common.envs.configs import EnvConfig
|
||||||
|
from lerobot.common.envs.utils import preprocess_observation
|
||||||
from lerobot.common.robot_devices.control_utils import (
|
from lerobot.common.robot_devices.control_utils import (
|
||||||
busy_wait,
|
busy_wait,
|
||||||
is_headless,
|
is_headless,
|
||||||
@@ -1713,6 +1714,50 @@ class GamepadControlWrapper(gym.Wrapper):
|
|||||||
return self.env.close()
|
return self.env.close()
|
||||||
|
|
||||||
|
|
||||||
|
class GymHilDeviceWrapper(gym.Wrapper):
|
||||||
|
def __init__(self, env, device="cpu"):
|
||||||
|
super().__init__(env)
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
def step(self, action):
|
||||||
|
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||||
|
for k in obs:
|
||||||
|
obs[k] = obs[k].to(self.device)
|
||||||
|
if "action_intervention" in info:
|
||||||
|
info["action_intervention"] = torch.from_numpy(info["action_intervention"]).to(self.device)
|
||||||
|
return obs, reward, terminated, truncated, info
|
||||||
|
|
||||||
|
def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None):
|
||||||
|
obs, info = self.env.reset(seed=seed, options=options)
|
||||||
|
for k in obs:
|
||||||
|
obs[k] = obs[k].to(self.device)
|
||||||
|
if "action_intervention" in info:
|
||||||
|
info["action_intervention"] = torch.from_numpy(info["action_intervention"]).to(self.device)
|
||||||
|
return obs, info
|
||||||
|
|
||||||
|
|
||||||
|
class GymHilObservationProcessorWrapper(gym.ObservationWrapper):
|
||||||
|
def __init__(self, env: gym.Env):
|
||||||
|
super().__init__(env)
|
||||||
|
prev_space = self.observation_space
|
||||||
|
new_space = {}
|
||||||
|
|
||||||
|
for key in prev_space:
|
||||||
|
if "pixels" in key:
|
||||||
|
for k in prev_space["pixels"]:
|
||||||
|
new_space[f"observation.images.{k}"] = gym.spaces.Box(
|
||||||
|
0.0, 255.0, shape=(3, 128, 128), dtype=np.uint8
|
||||||
|
)
|
||||||
|
|
||||||
|
if key == "agent_pos":
|
||||||
|
new_space["observation.state"] = prev_space["agent_pos"]
|
||||||
|
|
||||||
|
self.observation_space = gym.spaces.Dict(new_space)
|
||||||
|
|
||||||
|
def observation(self, observation: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
return preprocess_observation(observation)
|
||||||
|
|
||||||
|
|
||||||
###########################################################
|
###########################################################
|
||||||
# Factory functions
|
# Factory functions
|
||||||
###########################################################
|
###########################################################
|
||||||
@@ -1729,8 +1774,27 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
|
|||||||
cfg: Configuration object containing environment parameters.
|
cfg: Configuration object containing environment parameters.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
||||||
A vectorized gym environment with all necessary wrappers applied.
|
A vectorized gym environment with all necessary wrappers applied.
|
||||||
"""
|
"""
|
||||||
|
if cfg.type == "hil":
|
||||||
|
import gymnasium as gym
|
||||||
|
|
||||||
|
# TODO (azouitine)
|
||||||
|
env = gym.make(
|
||||||
|
f"gym_hil/{cfg.task}",
|
||||||
|
image_obs=True,
|
||||||
|
render_mode="human",
|
||||||
|
step_size=cfg.wrapper.ee_action_space_params.x_step_size,
|
||||||
|
use_gripper=cfg.wrapper.use_gripper,
|
||||||
|
gripper_penalty=cfg.wrapper.gripper_penalty,
|
||||||
|
)
|
||||||
|
env = GymHilObservationProcessorWrapper(env=env)
|
||||||
|
env = GymHilDeviceWrapper(env=env, device=cfg.device)
|
||||||
|
env = BatchCompatibleWrapper(env=env)
|
||||||
|
env = TorchActionWrapper(env=env, device=cfg.device)
|
||||||
|
return env
|
||||||
|
|
||||||
robot = make_robot_from_config(cfg.robot)
|
robot = make_robot_from_config(cfg.robot)
|
||||||
# Create base environment
|
# Create base environment
|
||||||
env = RobotEnv(
|
env = RobotEnv(
|
||||||
@@ -1883,6 +1947,11 @@ def record_dataset(env, policy, cfg, success_collection_steps=15):
|
|||||||
},
|
},
|
||||||
"next.reward": {"dtype": "float32", "shape": (1,), "names": None},
|
"next.reward": {"dtype": "float32", "shape": (1,), "names": None},
|
||||||
"next.done": {"dtype": "bool", "shape": (1,), "names": None},
|
"next.done": {"dtype": "bool", "shape": (1,), "names": None},
|
||||||
|
"complementary_info.discrete_penalty": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (1,),
|
||||||
|
"names": ["discrete_penalty"],
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add image features
|
# Add image features
|
||||||
@@ -1962,6 +2031,9 @@ def record_dataset(env, policy, cfg, success_collection_steps=15):
|
|||||||
|
|
||||||
frame["next.done"] = np.array([really_done], dtype=bool)
|
frame["next.done"] = np.array([really_done], dtype=bool)
|
||||||
frame["task"] = cfg.task
|
frame["task"] = cfg.task
|
||||||
|
frame["complementary_info.discrete_penalty"] = torch.tensor(
|
||||||
|
[info.get("discrete_penalty", 0.0)], dtype=torch.float32
|
||||||
|
)
|
||||||
dataset.add_frame(frame)
|
dataset.add_frame(frame)
|
||||||
|
|
||||||
# Maintain consistent timing
|
# Maintain consistent timing
|
||||||
@@ -2074,7 +2146,7 @@ def main(cfg: EnvConfig):
|
|||||||
|
|
||||||
num_episode = 0
|
num_episode = 0
|
||||||
successes = []
|
successes = []
|
||||||
while num_episode < 20:
|
while num_episode < 10:
|
||||||
start_loop_s = time.perf_counter()
|
start_loop_s = time.perf_counter()
|
||||||
# Sample a new random action from the robot's action space.
|
# Sample a new random action from the robot's action space.
|
||||||
new_random_action = env.action_space.sample()
|
new_random_action = env.action_space.sample()
|
||||||
|
|||||||
@@ -88,7 +88,7 @@ dora = [
|
|||||||
]
|
]
|
||||||
dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
|
dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
|
||||||
feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
|
feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
|
||||||
hilserl = ["transformers>=4.48.0", "torchmetrics>=1.6.0"]
|
hilserl = ["transformers>=4.48.0", "torchmetrics>=1.6.0", "gym-hil>=0.1.2"]
|
||||||
intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
|
intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
|
||||||
pi0 = ["transformers>=4.48.0"]
|
pi0 = ["transformers>=4.48.0"]
|
||||||
pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"]
|
pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"]
|
||||||
|
|||||||
Reference in New Issue
Block a user