Added a gripper control mechanism to gym_manipulator

Moved the HILSerl env configs to lerobot/common/envs/configs.py
Fixes in actor_server, modeling_sac, and configuration_sac
Added the option to ignore missing keys in env_cfg in the get_features_from_env_config function
Michel Aractingi
2025-03-28 08:21:36 +01:00
parent 88cc2b8fc8
commit 05a237ce10
7 changed files with 179 additions and 130 deletions
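Note on the last item of the commit message: the get_features_from_env_config hunk is not among the files excerpted below. A minimal sketch of what the ignore-missing-keys behavior could look like (the signature and the env_cfg.features mapping here are assumptions for illustration, not the actual lerobot API):

def get_features_from_env_config(env_cfg, ignore_missing_keys: bool = False) -> dict:
    """Collect policy features from the env config, optionally skipping absent keys."""
    features = {}
    for key in ("observation.image", "observation.state", "action"):
        try:
            features[key] = env_cfg.features[key]
        except KeyError:
            if not ignore_missing_keys:
                raise
            # With ignore_missing_keys=True, env configs that lack optional keys
            # (e.g. extra camera streams) no longer crash feature extraction.
    return features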

lerobot/scripts/server/actor_server.py

@@ -28,7 +28,6 @@ from torch.multiprocessing import Event, Queue
 # TODO: Remove the import of maniskill
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.policies.sac.modeling_sac import SACPolicy
-from lerobot.common.robot_devices.robots.utils import Robot, make_robot
 from lerobot.common.robot_devices.utils import busy_wait
 from lerobot.common.utils.random_utils import set_seed
 from lerobot.common.utils.utils import (
@@ -268,7 +267,7 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device)
 def act_with_policy(
     cfg: TrainPipelineConfig,
-    robot: Robot,
+    # robot: Robot,
     reward_classifier: nn.Module,
     shutdown_event: any,  # Event,
     parameters_queue: Queue,
@@ -287,7 +286,7 @@
     logging.info("make_env online")

-    online_env = make_robot_env( cfg=cfg.env)
+    online_env = make_robot_env(cfg=cfg.env)
     set_seed(cfg.seed)
     device = get_safe_torch_device(cfg.policy.device, log=True)
@@ -503,7 +502,6 @@ def actor_cli(cfg: TrainPipelineConfig):
     mp.set_start_method("spawn")
     init_logging(log_file="actor.log")

-    robot = make_robot(robot_type=cfg.env.robot)

     shutdown_event = setup_process_handlers(use_threads(cfg))
@@ -563,18 +561,17 @@ def actor_cli(cfg: TrainPipelineConfig):
     # HACK: FOR MANISKILL we do not have a reward classifier
     # TODO: Remove this once we merge into main
     reward_classifier = None
-    if (
-        cfg.env.reward_classifier["pretrained_path"] is not None
-        and cfg.env.reward_classifier["config_path"] is not None
-    ):
-        reward_classifier = get_classifier(
-            pretrained_path=cfg.env.reward_classifier["pretrained_path"],
-            config_path=cfg.env.reward_classifier["config_path"],
-        )
+    # if (
+    #     cfg.env.reward_classifier["pretrained_path"] is not None
+    #     and cfg.env.reward_classifier["config_path"] is not None
+    # ):
+    #     reward_classifier = get_classifier(
+    #         pretrained_path=cfg.env.reward_classifier["pretrained_path"],
+    #         config_path=cfg.env.reward_classifier["config_path"],
+    #     )

     act_with_policy(
         cfg=cfg,
-        robot=robot,
         reward_classifier=reward_classifier,
         shutdown_event=shutdown_event,
         parameters_queue=parameters_queue,

lerobot/scripts/server/end_effector_control_utils.py

@@ -29,6 +29,9 @@ class InputController:
         self.z_step_size = z_step_size
         self.running = True
         self.episode_end_status = None  # None, "success", or "failure"
+        self.intervention_flag = False
+        self.open_gripper_command = False
+        self.close_gripper_command = False

     def start(self):
         """Start the controller and initialize resources."""
@@ -70,6 +73,19 @@ class InputController:
             self.episode_end_status = None  # Reset after reading
         return status

+    def should_intervene(self):
+        """Return True if intervention flag was set."""
+        return self.intervention_flag
+
+    def gripper_command(self):
+        """Return the current gripper command."""
+        if self.open_gripper_command == self.close_gripper_command:
+            return "no-op"
+        elif self.open_gripper_command:
+            return "open"
+        elif self.close_gripper_command:
+            return "close"
+

 class KeyboardController(InputController):
     """Generate motion deltas from keyboard input."""
@@ -326,7 +342,6 @@ class GamepadControllerHID(InputController):
         self.buttons = {}
         self.quit_requested = False
         self.save_requested = False
-        self.intervention_flag = False

     def find_device(self):
         """Look for the gamepad device by vendor and product ID."""
@@ -416,7 +431,13 @@ class GamepadControllerHID(InputController):
            buttons = data[5]

            # Check if RB is pressed then the intervention flag should be set
-            self.intervention_flag = data[6] == 2
+            self.intervention_flag = data[6] in [2, 6, 10, 14]
+
+            # Check if RT is pressed
+            self.open_gripper_command = data[6] in [8, 10, 12]
+
+            # Check if LT is pressed
+            self.close_gripper_command = data[6] in [4, 6, 12]

            # Check if Y/Triangle button (bit 7) is pressed for saving
            # Check if X/Square button (bit 5) is pressed for failure
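The enumerated values decode a packed button byte. Reading 2, 4, and 8 as the RB, LT, and RT bits (an inference from the sums: 6 = 2+4, 10 = 2+8, 12 = 4+8, 14 = 2+4+8), an equivalent bitwise form would be:

RB, LT, RT = 0x02, 0x04, 0x08  # bit assignments inferred from the lists above

def decode_button_byte(byte6: int) -> tuple[bool, bool, bool]:
    intervention = bool(byte6 & RB)    # matches [2, 6, 10, 14]
    open_gripper = bool(byte6 & RT)    # matches [8, 10, 12], plus 14
    close_gripper = bool(byte6 & LT)   # matches [4, 6, 12], plus 14
    return intervention, open_gripper, close_gripper

The only divergence is 14 (all three buttons held): the membership lists leave both gripper flags False, while the bitwise form sets both True, and either way gripper_command() resolves equal flags to a no-op.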
@@ -676,12 +697,8 @@ def teleoperate_gym_env(env, controller, fps: int = 30):

 if __name__ == "__main__":
     from lerobot.common.robot_devices.robots.configs import RobotConfig
     from lerobot.common.robot_devices.robots.utils import make_robot_from_config
-    from lerobot.scripts.server.gym_manipulator import (
-        EEActionSpaceConfig,
-        EnvWrapperConfig,
-        HILSerlRobotEnvConfig,
-        make_robot_env,
-    )
+    from lerobot.scripts.server.gym_manipulator import make_robot_env
+    from lerobot.common.envs.configs import HILSerlRobotEnvConfig, EEActionSpaceConfig, EnvWrapperConfig

     parser = argparse.ArgumentParser(description="Test end-effector control")
     parser.add_argument(

lerobot/scripts/server/gym_manipulator.py

@@ -1,9 +1,8 @@
 import logging
 import sys
 import time
-from dataclasses import dataclass
 from threading import Lock
-from typing import Annotated, Any, Dict, Optional, Tuple
+from typing import Annotated, Any, Dict, Tuple

 import gymnasium as gym
 import numpy as np
@@ -17,66 +16,13 @@ from lerobot.common.robot_devices.control_utils import (
     is_headless,
     reset_follower_position,
 )
-from lerobot.common.robot_devices.robots.configs import RobotConfig
 from lerobot.common.robot_devices.robots.utils import make_robot_from_config
 from lerobot.common.utils.utils import log_say
 from lerobot.configs import parser
 from lerobot.scripts.server.kinematics import RobotKinematics

 logging.basicConfig(level=logging.INFO)

-
-@dataclass
-class EEActionSpaceConfig:
-    """Configuration parameters for end-effector action space."""
-
-    x_step_size: float
-    y_step_size: float
-    z_step_size: float
-    bounds: Dict[str, Any]  # Contains 'min' and 'max' keys with position bounds
-    use_gamepad: bool = False
-
-
-@dataclass
-class EnvWrapperConfig:
-    """Configuration for environment wrappers."""
-
-    display_cameras: bool = False
-    delta_action: float = 0.1
-    use_relative_joint_positions: bool = True
-    add_joint_velocity_to_observation: bool = False
-    add_ee_pose_to_observation: bool = False
-    crop_params_dict: Optional[Dict[str, Tuple[int, int, int, int]]] = None
-    resize_size: Optional[Tuple[int, int]] = None
-    control_time_s: float = 20.0
-    fixed_reset_joint_positions: Optional[Any] = None
-    reset_time_s: float = 5.0
-    joint_masking_action_space: Optional[Any] = None
-    ee_action_space_params: Optional[EEActionSpaceConfig] = None
-    reward_classifier_pretrained_path: Optional[str] = None
-    reward_classifier_config_file: Optional[str] = None
-
-
-@EnvConfig.register_subclass(name="gym_manipulator")
-@dataclass
-class HILSerlRobotEnvConfig(EnvConfig):
-    """Configuration for the HILSerlRobotEnv environment."""
-
-    robot: Optional[RobotConfig] = None
-    wrapper: Optional[EnvWrapperConfig] = None
-    fps: int = 10
-    mode: str = None  # Either "record", "replay", None
-    repo_id: Optional[str] = None
-    dataset_root: Optional[str] = None
-    task: str = ""
-    num_episodes: int = 10  # only for record mode
-    episode: int = 0
-    device: str = "cuda"
-    push_to_hub: bool = True
-    pretrained_policy_name_or_path: Optional[str] = None
-
-    def gym_kwargs(self) -> dict:
-        return {}
-
+MAX_GRIPPER_COMMAND = 25

 class HILSerlRobotEnv(gym.Env):
@@ -813,9 +759,10 @@ class BatchCompitableWrapper(gym.ObservationWrapper):

 class EEActionWrapper(gym.ActionWrapper):
-    def __init__(self, env, ee_action_space_params=None):
+    def __init__(self, env, ee_action_space_params=None, use_gripper=False):
         super().__init__(env)
         self.ee_action_space_params = ee_action_space_params
+        self.use_gripper = use_gripper

         # Initialize kinematics instance for the appropriate robot type
         robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so100")
@@ -829,10 +776,12 @@ class EEActionWrapper(gym.ActionWrapper):
                 ee_action_space_params.z_step_size,
             ]
         )
+        if self.use_gripper:
+            action_space_bounds = np.concatenate([action_space_bounds, [1.0]])
         ee_action_space = gym.spaces.Box(
             low=-action_space_bounds,
             high=action_space_bounds,
-            shape=(3,),
+            shape=(3 + int(self.use_gripper),),
             dtype=np.float32,
         )
         if isinstance(self.action_space, gym.spaces.Tuple):
@@ -848,6 +797,10 @@ class EEActionWrapper(gym.ActionWrapper):
         if isinstance(action, tuple):
             action, _ = action

+        if self.use_gripper:
+            gripper_command = action[-1]
+            action = action[:-1]
+
         current_joint_pos = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
         current_ee_pos = self.fk_function(current_joint_pos)
         if isinstance(action, torch.Tensor):
@@ -863,6 +816,12 @@ class EEActionWrapper(gym.ActionWrapper):
             position_only=True,
             fk_func=self.fk_function,
         )
+        if self.use_gripper:
+            gripper_command = gripper_command * MAX_GRIPPER_COMMAND
+            gripper_state = self.unwrapped.robot.follower_arms["main"].read("Present_Position")[-1]
+            gripper_action = np.clip(gripper_state + gripper_command, 0, MAX_GRIPPER_COMMAND)
+            target_joint_pos[-1] = gripper_action
+
         return target_joint_pos, is_intervention
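Since the gamepad emits ±1.0 on the gripper channel and MAX_GRIPPER_COMMAND is 25, each environment step nudges the gripper by up to 25 position units, clipped to [0, 25]. A standalone restatement of the delta logic above, with worked values:

import numpy as np

MAX_GRIPPER_COMMAND = 25

def apply_gripper_delta(gripper_state: float, command: float) -> float:
    # command is the last action component in [-1, 1]; scale it to joint
    # units and clip to the valid gripper range, as the wrapper does.
    return float(np.clip(gripper_state + command * MAX_GRIPPER_COMMAND, 0, MAX_GRIPPER_COMMAND))

assert apply_gripper_delta(10.0, 1.0) == 25.0   # full open saturates at the max
assert apply_gripper_delta(10.0, -0.2) == 5.0   # partial close: 10 - 5
assert apply_gripper_delta(0.0, -1.0) == 0.0    # already closed stays closed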
@@ -912,6 +871,7 @@ class GamepadControlWrapper(gym.Wrapper):
         x_step_size=1.0,
         y_step_size=1.0,
         z_step_size=1.0,
+        use_gripper=False,
         auto_reset=False,
         input_threshold=0.001,
     ):
@@ -948,6 +908,7 @@ class GamepadControlWrapper(gym.Wrapper):
             z_step_size=z_step_size,
         )
         self.auto_reset = auto_reset
+        self.use_gripper = use_gripper
         self.input_threshold = input_threshold
         self.controller.start()
@@ -977,6 +938,15 @@ class GamepadControlWrapper(gym.Wrapper):
         # Create action from gamepad input
         gamepad_action = np.array([delta_x, delta_y, delta_z], dtype=np.float32)

+        if self.use_gripper:
+            gripper_command = self.controller.gripper_command()
+            if gripper_command == "open":
+                gamepad_action = np.concatenate([gamepad_action, [1.0]])
+            elif gripper_command == "close":
+                gamepad_action = np.concatenate([gamepad_action, [-1.0]])
+            else:
+                gamepad_action = np.concatenate([gamepad_action, [0.0]])
+
         # Check episode ending buttons
         # We'll rely on controller.get_episode_end_status() which returns "success", "failure", or None
         episode_end_status = self.controller.get_episode_end_status()
@@ -1023,6 +993,7 @@ class GamepadControlWrapper(gym.Wrapper):
                 final_action = (torch.from_numpy(gamepad_action), False)
             else:
                 final_action = torch.from_numpy(gamepad_action)
+
         else:
             # Use the original action
             final_action = action
@@ -1138,7 +1109,11 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
     # env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
     env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps)
     if cfg.wrapper.ee_action_space_params is not None:
-        env = EEActionWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params)
+        env = EEActionWrapper(
+            env=env,
+            ee_action_space_params=cfg.wrapper.ee_action_space_params,
+            use_gripper=cfg.wrapper.use_gripper,
+        )
     if cfg.wrapper.ee_action_space_params is not None and cfg.wrapper.ee_action_space_params.use_gamepad:
         # env = ActionScaleWrapper(env=env, ee_action_space_params=cfg.wrapper.ee_action_space_params)
         env = GamepadControlWrapper(
@@ -1146,6 +1121,7 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
             x_step_size=cfg.wrapper.ee_action_space_params.x_step_size,
             y_step_size=cfg.wrapper.ee_action_space_params.y_step_size,
             z_step_size=cfg.wrapper.ee_action_space_params.z_step_size,
+            use_gripper=cfg.wrapper.use_gripper,
         )
     else:
         env = KeyboardInterfaceWrapper(env=env)
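make_robot_env now reads cfg.wrapper.use_gripper, a field the removed EnvWrapperConfig above did not have, so it presumably lives on the relocated dataclass in lerobot/common/envs/configs.py (that file's hunk is not shown in this excerpt). A hypothetical configuration enabling the gripper; field values are illustrative only:

from lerobot.common.envs.configs import (
    EEActionSpaceConfig,
    EnvWrapperConfig,
    HILSerlRobotEnvConfig,
)

cfg = HILSerlRobotEnvConfig(
    fps=10,
    # robot=... (a RobotConfig) would also be required on real hardware
    wrapper=EnvWrapperConfig(
        use_gripper=True,  # assumed new field backing cfg.wrapper.use_gripper
        ee_action_space_params=EEActionSpaceConfig(
            x_step_size=0.02,
            y_step_size=0.02,
            z_step_size=0.02,
            bounds={"min": [-0.5, -0.5, 0.0], "max": [0.5, 0.5, 0.5]},
            use_gamepad=True,  # routes actions through GamepadControlWrapper
        ),
    ),
)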
@@ -1184,7 +1160,7 @@ def get_classifier(cfg):
     return model


-def record_dataset(env, policy, cfg: HILSerlRobotEnvConfig):
+def record_dataset(env, policy, cfg):
     """
     Record a dataset of robot interactions using either a policy or teleop.