Added an option to include motor current readings in the policy's state observation
@@ -192,6 +192,7 @@ class EnvWrapperConfig:
     display_cameras: bool = False
     use_relative_joint_positions: bool = True
     add_joint_velocity_to_observation: bool = False
+    add_current_to_observation: bool = False
     add_ee_pose_to_observation: bool = False
     crop_params_dict: Optional[Dict[str, Tuple[int, int, int, int]]] = None
     resize_size: Optional[Tuple[int, int]] = None
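The new `add_current_to_observation` flag defaults to `False`, so existing configs are unaffected; a policy that should see motor currents opts in explicitly. A minimal usage sketch (assuming the patched `EnvWrapperConfig` above):

cfg = EnvWrapperConfig(
    add_joint_velocity_to_observation=True,  # existing option
    add_current_to_observation=True,         # new option from this commit
)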
@@ -258,24 +258,24 @@ class GamepadController(InputController):
                 elif event.button == 0:
                     self.episode_end_status = "rerecord_episode"

-                # RB button (6) for opening gripper
+                # RB button (6) for closing gripper
                 elif event.button == 6:
-                    self.open_gripper_command = True
-
-                # LT button (7) for closing gripper
-                elif event.button == 7:
                     self.close_gripper_command = True
+
+                # LT button (7) for opening gripper
+                elif event.button == 7:
+                    self.open_gripper_command = True

             # Reset episode status on button release
             elif event.type == pygame.JOYBUTTONUP:
                 if event.button in [0, 2, 3]:
                     self.episode_end_status = None

                 elif event.button == 6:
-                    self.open_gripper_command = False
+                    self.close_gripper_command = False

                 elif event.button == 7:
-                    self.close_gripper_command = False
+                    self.open_gripper_command = False

         # Check for RB button (typically button 5) for intervention flag
         if self.joystick.get_button(5):
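With this change, RB (button 6) closes the gripper and LT (button 7) opens it, and each button-up branch now clears the same flag its button sets, so a release can no longer cancel the other button's command. A condensed sketch of the resulting mapping (standalone; `controller` stands in for the GamepadController instance):

import pygame

# Button index -> the gripper flag it owns after this commit.
GRIPPER_BUTTONS = {6: "close_gripper_command", 7: "open_gripper_command"}

def handle_gripper_event(controller, event):
    if event.type == pygame.JOYBUTTONDOWN and event.button in GRIPPER_BUTTONS:
        setattr(controller, GRIPPER_BUTTONS[event.button], True)
    elif event.type == pygame.JOYBUTTONUP and event.button in GRIPPER_BUTTONS:
        setattr(controller, GRIPPER_BUTTONS[event.button], False)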
@@ -22,7 +22,7 @@ from lerobot.configs import parser
 from lerobot.scripts.server.kinematics import RobotKinematics

 logging.basicConfig(level=logging.INFO)
-MAX_GRIPPER_COMMAND = 25
+MAX_GRIPPER_COMMAND = 40


 class HILSerlRobotEnv(gym.Env):
@@ -304,7 +304,7 @@ class HILSerlRobotEnv(gym.Env):


 class AddJointVelocityToObservation(gym.ObservationWrapper):
-    def __init__(self, env, joint_velocity_limits=100.0, fps=30):
+    def __init__(self, env, joint_velocity_limits=100.0, fps=30, num_dof=6):
         super().__init__(env)

         # Extend observation space to include joint velocities
@@ -312,12 +312,12 @@ class AddJointVelocityToObservation(gym.ObservationWrapper):
         old_high = self.observation_space["observation.state"].high
         old_shape = self.observation_space["observation.state"].shape

-        self.last_joint_positions = np.zeros(old_shape)
+        self.last_joint_positions = np.zeros(num_dof)

-        new_low = np.concatenate([old_low, np.ones_like(old_low) * -joint_velocity_limits])
-        new_high = np.concatenate([old_high, np.ones_like(old_high) * joint_velocity_limits])
+        new_low = np.concatenate([old_low, np.ones(num_dof) * -joint_velocity_limits])
+        new_high = np.concatenate([old_high, np.ones(num_dof) * joint_velocity_limits])

-        new_shape = (old_shape[0] * 2,)
+        new_shape = (old_shape[0] + num_dof,)

         self.observation_space["observation.state"] = gym.spaces.Box(
             low=new_low,
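The `num_dof` parameter decouples the appended velocity block from the size of `observation.state`: the old code sized it from `old_shape`, which is only correct when the state holds exactly the joint positions. A self-contained check of the new bound arithmetic (the 8-element state is a made-up example of a state that already carries extra features):

import numpy as np

num_dof = 6
joint_velocity_limits = 100.0

# State that already holds 6 joint positions plus 2 extra features.
old_low = np.full((8,), -np.pi, dtype=np.float32)
old_high = np.full((8,), np.pi, dtype=np.float32)

new_low = np.concatenate([old_low, np.ones(num_dof) * -joint_velocity_limits])
new_high = np.concatenate([old_high, np.ones(num_dof) * joint_velocity_limits])

# 8 + 6 = 14 entries, where the old `old_shape[0] * 2` would have claimed 16.
assert new_low.shape == new_high.shape == (8 + num_dof,)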
@@ -337,6 +337,37 @@ class AddJointVelocityToObservation(gym.ObservationWrapper):
         return observation


+class AddCurrentToObservation(gym.ObservationWrapper):
+    def __init__(self, env, max_current=500, num_dof=6):
+        super().__init__(env)
+
+        # Extend observation space to include motor currents
+        old_low = self.observation_space["observation.state"].low
+        old_high = self.observation_space["observation.state"].high
+        old_shape = self.observation_space["observation.state"].shape
+
+
+        new_low = np.concatenate([old_low, np.zeros(num_dof)])
+        new_high = np.concatenate([old_high, np.ones(num_dof) * max_current])
+
+        new_shape = (old_shape[0] + num_dof,)
+
+        self.observation_space["observation.state"] = gym.spaces.Box(
+            low=new_low,
+            high=new_high,
+            shape=new_shape,
+            dtype=np.float32,
+        )
+
+
+    def observation(self, observation):
+        present_current = self.unwrapped.robot.follower_arms["main"].read("Present_Current").astype(np.float32)
+        observation["observation.state"] = torch.cat(
+            [observation["observation.state"], torch.from_numpy(present_current)], dim=-1
+        )
+        return observation
+
+
 class ActionRepeatWrapper(gym.Wrapper):
     def __init__(self, env, nb_repeat: int = 1):
         super().__init__(env)
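`AddCurrentToObservation` mirrors `AddJointVelocityToObservation`: the Box bounds are widened once at construction ([0, max_current] per motor), and the live `Present_Current` register is appended to `observation.state` on every step. A usage sketch (`base_env` and `base_dim` are stand-ins for an already-built robot env and its state size):

env = AddCurrentToObservation(env=base_env, max_current=500, num_dof=6)

obs, info = env.reset()
# The last num_dof entries of observation.state are now raw current readings.
assert obs["observation.state"].shape[-1] == base_dim + 6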
@@ -553,6 +584,9 @@ class ImageCropResizeWrapper(gym.Wrapper):
                 # TODO(michel-aractingi): Bug in resize, it returns values outside [0, 1]
                 obs[k] = obs[k].clamp(0.0, 1.0)

+                # import cv2
+                # cv2.imwrite(f"tmp_img/{k}.jpg", obs[k].squeeze(0).permute(1, 2, 0).cpu().numpy() * 255)
+
                 # Check for NaNs after processing
                 if torch.isnan(obs[k]).any():
                     logging.error(f"NaN values detected in observation {k} after crop and resize")
@@ -812,16 +846,26 @@ class GripperPenaltyWrapper(gym.RewardWrapper):


 class GripperActionWrapper(gym.ActionWrapper):
-    def __init__(self, env, quantization_threshold: float = 0.2):
+    def __init__(self, env, quantization_threshold: float = 0.2, gripper_sleep: float = 0.0):
         super().__init__(env)
         self.quantization_threshold = quantization_threshold
+        self.gripper_sleep = gripper_sleep
+        self.last_gripper_action_time = 0.0
+        self.last_gripper_action = None

     def action(self, action):
         is_intervention = False
         if isinstance(action, tuple):
             action, is_intervention = action
-        gripper_command = action[-1]

+        if self.gripper_sleep > 0.0:
+            if self.last_gripper_action is not None and time.perf_counter() - self.last_gripper_action_time < self.gripper_sleep:
+                action[-1] = self.last_gripper_action
+            else:
+                self.last_gripper_action_time = time.perf_counter()
+                self.last_gripper_action = action[-1]
+
+        gripper_command = action[-1]
         # Gripper actions are between 0, 2
         # we want to quantize them to -1, 0 or 1
         gripper_command = gripper_command - 1.0
@@ -837,6 +881,12 @@ class GripperActionWrapper(gym.ActionWrapper):
         action[-1] = gripper_action.item()
         return action, is_intervention

+    def reset(self, **kwargs):
+        obs, info = super().reset(**kwargs)
+        self.last_gripper_action_time = 0.0
+        self.last_gripper_action = None
+        return obs, info
+

 class EEActionWrapper(gym.ActionWrapper):
     def __init__(self, env, ee_action_space_params=None, use_gripper=False):
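The new `gripper_sleep` parameter debounces the gripper: any command arriving within `gripper_sleep` seconds of the last accepted one is replaced by that last command, which suppresses rapid open/close chatter on real hardware, and `reset()` clears the stored command between episodes. The same logic, distilled into a standalone sketch:

import time

class GripperDebounce:
    """Hold the last accepted command until sleep_s seconds have elapsed."""

    def __init__(self, sleep_s: float):
        self.sleep_s = sleep_s
        self.last_time = 0.0
        self.last_command = None

    def filter(self, command: float) -> float:
        now = time.perf_counter()
        if self.last_command is not None and now - self.last_time < self.sleep_s:
            return self.last_command  # too soon: repeat the previous command
        self.last_time = now
        self.last_command = command
        return command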
@@ -1171,6 +1221,8 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
     # Add observation and image processing
     if cfg.wrapper.add_joint_velocity_to_observation:
         env = AddJointVelocityToObservation(env=env, fps=cfg.fps)
+    if cfg.wrapper.add_current_to_observation:
+        env = AddCurrentToObservation(env=env)
    if cfg.wrapper.add_ee_pose_to_observation:
         env = EEObservationWrapper(env=env, ee_pose_limits=cfg.wrapper.ee_action_space_params.bounds)

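Note the insertion point in `make_robot_env`: each of these wrappers appends to `observation.state`, so with all flags enabled the final layout stacks in wrapper order. A sketch of how a policy could slice the state back apart (indices assume a bare 6-DoF state of joint positions; `obs` is a stand-in for one wrapped observation):

# [0:6]   joint positions   (base env)
# [6:12]  joint velocities  (AddJointVelocityToObservation)
# [12:18] motor currents    (AddCurrentToObservation)
state = obs["observation.state"]
positions, velocities, currents = state[..., 0:6], state[..., 6:12], state[..., 12:18]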
@@ -1454,7 +1506,7 @@ def main(cfg: EnvConfig):

         dt_s = time.perf_counter() - start_loop_s
         busy_wait(1 / cfg.fps - dt_s)


-    logging.info(f"Success after 20 steps {sucesses}")
+    logging.info(f"success rate {sum(sucesses) / len(sucesses)}")
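The evaluation loop above runs at a fixed control rate: it measures how long each step took and burns off the rest of the frame budget with `busy_wait(1 / cfg.fps - dt_s)`. A self-contained sketch of the pattern (this `busy_wait` is a stand-in for the lerobot utility):

import time

def busy_wait(seconds: float) -> None:
    # Spin until the deadline; a no-op when the step already overran its budget.
    deadline = time.perf_counter() + seconds
    while time.perf_counter() < deadline:
        pass

fps = 30
for _ in range(3):
    start_loop_s = time.perf_counter()
    # ... one environment step would run here ...
    dt_s = time.perf_counter() - start_loop_s
    busy_wait(1 / fps - dt_s)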