[HIL SERL] Env management and add gym-hil (#1077)
Co-authored-by: Michel Aractingi <michel.aractingi@gmail.com>
This commit is contained in:
committed by
AdilZouitine
parent
e76f29ff7a
commit
049773a5fa
@@ -11,6 +11,7 @@ import torch
|
||||
import torchvision.transforms.functional as F # noqa: N812
|
||||
|
||||
from lerobot.common.envs.configs import EnvConfig
|
||||
from lerobot.common.envs.utils import preprocess_observation
|
||||
from lerobot.common.robot_devices.control_utils import (
|
||||
busy_wait,
|
||||
is_headless,
|
||||
@@ -1713,6 +1714,50 @@ class GamepadControlWrapper(gym.Wrapper):
|
||||
return self.env.close()
|
||||
|
||||
|
||||
class GymHilDeviceWrapper(gym.Wrapper):
|
||||
def __init__(self, env, device="cpu"):
|
||||
super().__init__(env)
|
||||
self.device = device
|
||||
|
||||
def step(self, action):
|
||||
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||
for k in obs:
|
||||
obs[k] = obs[k].to(self.device)
|
||||
if "action_intervention" in info:
|
||||
info["action_intervention"] = torch.from_numpy(info["action_intervention"]).to(self.device)
|
||||
return obs, reward, terminated, truncated, info
|
||||
|
||||
def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None):
|
||||
obs, info = self.env.reset(seed=seed, options=options)
|
||||
for k in obs:
|
||||
obs[k] = obs[k].to(self.device)
|
||||
if "action_intervention" in info:
|
||||
info["action_intervention"] = torch.from_numpy(info["action_intervention"]).to(self.device)
|
||||
return obs, info
|
||||
|
||||
|
||||
class GymHilObservationProcessorWrapper(gym.ObservationWrapper):
|
||||
def __init__(self, env: gym.Env):
|
||||
super().__init__(env)
|
||||
prev_space = self.observation_space
|
||||
new_space = {}
|
||||
|
||||
for key in prev_space:
|
||||
if "pixels" in key:
|
||||
for k in prev_space["pixels"]:
|
||||
new_space[f"observation.images.{k}"] = gym.spaces.Box(
|
||||
0.0, 255.0, shape=(3, 128, 128), dtype=np.uint8
|
||||
)
|
||||
|
||||
if key == "agent_pos":
|
||||
new_space["observation.state"] = prev_space["agent_pos"]
|
||||
|
||||
self.observation_space = gym.spaces.Dict(new_space)
|
||||
|
||||
def observation(self, observation: dict[str, Any]) -> dict[str, Any]:
|
||||
return preprocess_observation(observation)
|
||||
|
||||
|
||||
###########################################################
|
||||
# Factory functions
|
||||
###########################################################
|
||||
@@ -1729,8 +1774,27 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
|
||||
cfg: Configuration object containing environment parameters.
|
||||
|
||||
Returns:
|
||||
|
||||
A vectorized gym environment with all necessary wrappers applied.
|
||||
"""
|
||||
if cfg.type == "hil":
|
||||
import gymnasium as gym
|
||||
|
||||
# TODO (azouitine)
|
||||
env = gym.make(
|
||||
f"gym_hil/{cfg.task}",
|
||||
image_obs=True,
|
||||
render_mode="human",
|
||||
step_size=cfg.wrapper.ee_action_space_params.x_step_size,
|
||||
use_gripper=cfg.wrapper.use_gripper,
|
||||
gripper_penalty=cfg.wrapper.gripper_penalty,
|
||||
)
|
||||
env = GymHilObservationProcessorWrapper(env=env)
|
||||
env = GymHilDeviceWrapper(env=env, device=cfg.device)
|
||||
env = BatchCompatibleWrapper(env=env)
|
||||
env = TorchActionWrapper(env=env, device=cfg.device)
|
||||
return env
|
||||
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
# Create base environment
|
||||
env = RobotEnv(
|
||||
@@ -1883,6 +1947,11 @@ def record_dataset(env, policy, cfg, success_collection_steps=15):
|
||||
},
|
||||
"next.reward": {"dtype": "float32", "shape": (1,), "names": None},
|
||||
"next.done": {"dtype": "bool", "shape": (1,), "names": None},
|
||||
"complementary_info.discrete_penalty": {
|
||||
"dtype": "float32",
|
||||
"shape": (1,),
|
||||
"names": ["discrete_penalty"],
|
||||
},
|
||||
}
|
||||
|
||||
# Add image features
|
||||
@@ -1962,6 +2031,9 @@ def record_dataset(env, policy, cfg, success_collection_steps=15):
|
||||
|
||||
frame["next.done"] = np.array([really_done], dtype=bool)
|
||||
frame["task"] = cfg.task
|
||||
frame["complementary_info.discrete_penalty"] = torch.tensor(
|
||||
[info.get("discrete_penalty", 0.0)], dtype=torch.float32
|
||||
)
|
||||
dataset.add_frame(frame)
|
||||
|
||||
# Maintain consistent timing
|
||||
@@ -2074,7 +2146,7 @@ def main(cfg: EnvConfig):
|
||||
|
||||
num_episode = 0
|
||||
successes = []
|
||||
while num_episode < 20:
|
||||
while num_episode < 10:
|
||||
start_loop_s = time.perf_counter()
|
||||
# Sample a new random action from the robot's action space.
|
||||
new_random_action = env.action_space.sample()
|
||||
|
||||
Reference in New Issue
Block a user