Added option to include motor current readings in the policy's observation state

Michel Aractingi
2025-04-15 15:07:43 +02:00
parent 3b24ad3c84
commit 9886520d33
3 changed files with 69 additions and 16 deletions

View File

@@ -192,6 +192,7 @@ class EnvWrapperConfig:
display_cameras: bool = False
use_relative_joint_positions: bool = True
add_joint_velocity_to_observation: bool = False
add_current_to_observation: bool = False
add_ee_pose_to_observation: bool = False
crop_params_dict: Optional[Dict[str, Tuple[int, int, int, int]]] = None
resize_size: Optional[Tuple[int, int]] = None
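The new `add_current_to_observation` flag defaults to `False` and is consumed further down in `make_robot_env`, where it gates the `AddCurrentToObservation` wrapper. A minimal, hypothetical sketch of turning it on alongside the existing observation options (the config module's import path is not shown in this diff):

```python
# Hypothetical usage sketch; adjust the import to wherever EnvWrapperConfig lives.
# from lerobot... import EnvWrapperConfig

cfg = EnvWrapperConfig(
    add_joint_velocity_to_observation=True,  # existing option
    add_current_to_observation=True,         # new option introduced by this commit
    add_ee_pose_to_observation=False,
)
```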

View File

@@ -258,24 +258,24 @@ class GamepadController(InputController):
elif event.button == 0:
self.episode_end_status = "rerecord_episode"
# RB button (6) for opening gripper
# RB button (6) for closing gripper
elif event.button == 6:
self.open_gripper_command = True
# LT button (7) for closing gripper
elif event.button == 7:
self.close_gripper_command = True
# LT button (7) for opening gripper
elif event.button == 7:
self.open_gripper_command = True
# Reset episode status on button release
elif event.type == pygame.JOYBUTTONUP:
if event.button in [0, 2, 3]:
self.episode_end_status = None
elif event.button == 6:
self.open_gripper_command = False
self.close_gripper_command = False
elif event.button == 7:
self.close_gripper_command = False
self.open_gripper_command = False
# Check for RB button (typically button 5) for intervention flag
if self.joystick.get_button(5):

View File

@@ -22,7 +22,7 @@ from lerobot.configs import parser
from lerobot.scripts.server.kinematics import RobotKinematics
logging.basicConfig(level=logging.INFO)
MAX_GRIPPER_COMMAND = 25
MAX_GRIPPER_COMMAND = 40
class HILSerlRobotEnv(gym.Env):
@@ -304,7 +304,7 @@ class HILSerlRobotEnv(gym.Env):
class AddJointVelocityToObservation(gym.ObservationWrapper):
def __init__(self, env, joint_velocity_limits=100.0, fps=30):
def __init__(self, env, joint_velocity_limits=100.0, fps=30, num_dof=6):
super().__init__(env)
# Extend observation space to include joint velocities
@@ -312,12 +312,12 @@ class AddJointVelocityToObservation(gym.ObservationWrapper):
old_high = self.observation_space["observation.state"].high
old_shape = self.observation_space["observation.state"].shape
self.last_joint_positions = np.zeros(old_shape)
self.last_joint_positions = np.zeros(num_dof)
new_low = np.concatenate([old_low, np.ones_like(old_low) * -joint_velocity_limits])
new_high = np.concatenate([old_high, np.ones_like(old_high) * joint_velocity_limits])
new_low = np.concatenate([old_low, np.ones(num_dof) * -joint_velocity_limits])
new_high = np.concatenate([old_high, np.ones(num_dof) * joint_velocity_limits])
new_shape = (old_shape[0] * 2,)
new_shape = (old_shape[0] + num_dof,)
self.observation_space["observation.state"] = gym.spaces.Box(
low=new_low,
@@ -337,6 +337,37 @@ class AddJointVelocityToObservation(gym.ObservationWrapper):
return observation
class AddCurrentToObservation(gym.ObservationWrapper):
def __init__(self, env, max_current=500, num_dof=6):
super().__init__(env)
# Extend observation space to include motor currents
old_low = self.observation_space["observation.state"].low
old_high = self.observation_space["observation.state"].high
old_shape = self.observation_space["observation.state"].shape
new_low = np.concatenate([old_low, np.zeros(num_dof)])
new_high = np.concatenate([old_high, np.ones(num_dof) * max_current])
new_shape = (old_shape[0] + num_dof,)
self.observation_space["observation.state"] = gym.spaces.Box(
low=new_low,
high=new_high,
shape=new_shape,
dtype=np.float32,
)
def observation(self, observation):
present_current = self.unwrapped.robot.follower_arms["main"].read("Present_Current").astype(np.float32)
observation["observation.state"] = torch.cat(
[observation["observation.state"], torch.from_numpy(present_current)], dim=-1
)
return observation
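The wrapper above appends one `Present_Current` reading per motor to `observation.state`, so the state vector grows by `num_dof` entries bounded by `[0, max_current]`. A hypothetical sketch of the space bookkeeping against a stub env (assumes `gymnasium` is imported as `gym`, as elsewhere in the file; no real robot is needed because `observation()` is never called here):

```python
import gymnasium as gym
import numpy as np


class _StubEnv(gym.Env):
    """Hypothetical stand-in with a 6-dim observation.state and no hardware."""

    def __init__(self):
        self.observation_space = gym.spaces.Dict(
            {"observation.state": gym.spaces.Box(-100.0, 100.0, shape=(6,), dtype=np.float32)}
        )
        self.action_space = gym.spaces.Box(-1.0, 1.0, shape=(6,), dtype=np.float32)


env = AddCurrentToObservation(_StubEnv(), max_current=500, num_dof=6)
state_space = env.observation_space["observation.state"]
print(state_space.shape)      # (12,): 6 joint positions + 6 current readings
print(state_space.low[-6:])   # zeros: currents are assumed non-negative
print(state_space.high[-6:])  # 500.0 for each appended current dimension
```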
class ActionRepeatWrapper(gym.Wrapper):
def __init__(self, env, nb_repeat: int = 1):
super().__init__(env)
@@ -553,6 +584,9 @@ class ImageCropResizeWrapper(gym.Wrapper):
# TODO(michel-aractingi): Bug in resize, it returns values outside [0, 1]
obs[k] = obs[k].clamp(0.0, 1.0)
# import cv2
# cv2.imwrite(f"tmp_img/{k}.jpg", obs[k].squeeze(0).permute(1, 2, 0).cpu().numpy() * 255)
# Check for NaNs after processing
if torch.isnan(obs[k]).any():
logging.error(f"NaN values detected in observation {k} after crop and resize")
@@ -812,16 +846,26 @@ class GripperPenaltyWrapper(gym.RewardWrapper):
class GripperActionWrapper(gym.ActionWrapper):
def __init__(self, env, quantization_threshold: float = 0.2):
def __init__(self, env, quantization_threshold: float = 0.2, gripper_sleep: float = 0.0):
super().__init__(env)
self.quantization_threshold = quantization_threshold
self.gripper_sleep = gripper_sleep
self.last_gripper_action_time = 0.0
self.last_gripper_action = None
def action(self, action):
is_intervention = False
if isinstance(action, tuple):
action, is_intervention = action
gripper_command = action[-1]
if self.gripper_sleep > 0.0:
if self.last_gripper_action is not None and time.perf_counter() - self.last_gripper_action_time < self.gripper_sleep:
action[-1] = self.last_gripper_action
else:
self.last_gripper_action_time = time.perf_counter()
self.last_gripper_action = action[-1]
gripper_command = action[-1]
# Gripper actions are between 0 and 2
# we want to quantize them to -1, 0 or 1
gripper_command = gripper_command - 1.0
@@ -837,6 +881,12 @@ class GripperActionWrapper(gym.ActionWrapper):
action[-1] = gripper_action.item()
return action, is_intervention
def reset(self, **kwargs):
obs, info = super().reset(**kwargs)
self.last_gripper_action_time = 0.0
self.last_gripper_action = None
return obs, info
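The new `gripper_sleep` argument debounces gripper commands: if a gripper command was issued less than `gripper_sleep` seconds ago, the incoming command is replaced with the previous one; otherwise the timestamp and command are refreshed. A hypothetical standalone sketch mirroring that debounce logic:

```python
import time

GRIPPER_SLEEP = 0.5  # hold the last gripper command for half a second (illustrative value)
_last_time = 0.0
_last_cmd = None


def debounce_gripper(cmd: float) -> float:
    """Repeat the previous gripper command until GRIPPER_SLEEP seconds have elapsed."""
    global _last_time, _last_cmd
    if _last_cmd is not None and time.perf_counter() - _last_time < GRIPPER_SLEEP:
        return _last_cmd
    _last_time, _last_cmd = time.perf_counter(), cmd
    return cmd
```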
class EEActionWrapper(gym.ActionWrapper):
def __init__(self, env, ee_action_space_params=None, use_gripper=False):
@@ -1171,6 +1221,8 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
# Add observation and image processing
if cfg.wrapper.add_joint_velocity_to_observation:
env = AddJointVelocityToObservation(env=env, fps=cfg.fps)
if cfg.wrapper.add_current_to_observation:
env = AddCurrentToObservation(env=env)
if cfg.wrapper.add_ee_pose_to_observation:
env = EEObservationWrapper(env=env, ee_pose_limits=cfg.wrapper.ee_action_space_params.bounds)
@@ -1454,7 +1506,7 @@ def main(cfg: EnvConfig):
dt_s = time.perf_counter() - start_loop_s
busy_wait(1 / cfg.fps - dt_s)
logging.info(f"Success after 20 steps {sucesses}")
logging.info(f"success rate {sum(sucesses) / len(sucesses)}")