forked from tangger/lerobot
Added option to add current readings to the state of the policy
This commit is contained in:
@@ -192,6 +192,7 @@ class EnvWrapperConfig:
|
|||||||
display_cameras: bool = False
|
display_cameras: bool = False
|
||||||
use_relative_joint_positions: bool = True
|
use_relative_joint_positions: bool = True
|
||||||
add_joint_velocity_to_observation: bool = False
|
add_joint_velocity_to_observation: bool = False
|
||||||
|
add_current_to_observation: bool = False
|
||||||
add_ee_pose_to_observation: bool = False
|
add_ee_pose_to_observation: bool = False
|
||||||
crop_params_dict: Optional[Dict[str, Tuple[int, int, int, int]]] = None
|
crop_params_dict: Optional[Dict[str, Tuple[int, int, int, int]]] = None
|
||||||
resize_size: Optional[Tuple[int, int]] = None
|
resize_size: Optional[Tuple[int, int]] = None
|
||||||
|
|||||||
@@ -258,24 +258,24 @@ class GamepadController(InputController):
|
|||||||
elif event.button == 0:
|
elif event.button == 0:
|
||||||
self.episode_end_status = "rerecord_episode"
|
self.episode_end_status = "rerecord_episode"
|
||||||
|
|
||||||
# RB button (6) for opening gripper
|
# RB button (6) for closing gripper
|
||||||
elif event.button == 6:
|
elif event.button == 6:
|
||||||
self.open_gripper_command = True
|
|
||||||
|
|
||||||
# LT button (7) for closing gripper
|
|
||||||
elif event.button == 7:
|
|
||||||
self.close_gripper_command = True
|
self.close_gripper_command = True
|
||||||
|
|
||||||
|
# LT button (7) for opening gripper
|
||||||
|
elif event.button == 7:
|
||||||
|
self.open_gripper_command = True
|
||||||
|
|
||||||
# Reset episode status on button release
|
# Reset episode status on button release
|
||||||
elif event.type == pygame.JOYBUTTONUP:
|
elif event.type == pygame.JOYBUTTONUP:
|
||||||
if event.button in [0, 2, 3]:
|
if event.button in [0, 2, 3]:
|
||||||
self.episode_end_status = None
|
self.episode_end_status = None
|
||||||
|
|
||||||
elif event.button == 6:
|
elif event.button == 6:
|
||||||
self.open_gripper_command = False
|
self.close_gripper_command = False
|
||||||
|
|
||||||
elif event.button == 7:
|
elif event.button == 7:
|
||||||
self.close_gripper_command = False
|
self.open_gripper_command = False
|
||||||
|
|
||||||
# Check for RB button (typically button 5) for intervention flag
|
# Check for RB button (typically button 5) for intervention flag
|
||||||
if self.joystick.get_button(5):
|
if self.joystick.get_button(5):
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ from lerobot.configs import parser
|
|||||||
from lerobot.scripts.server.kinematics import RobotKinematics
|
from lerobot.scripts.server.kinematics import RobotKinematics
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
MAX_GRIPPER_COMMAND = 25
|
MAX_GRIPPER_COMMAND = 40
|
||||||
|
|
||||||
|
|
||||||
class HILSerlRobotEnv(gym.Env):
|
class HILSerlRobotEnv(gym.Env):
|
||||||
@@ -304,7 +304,7 @@ class HILSerlRobotEnv(gym.Env):
|
|||||||
|
|
||||||
|
|
||||||
class AddJointVelocityToObservation(gym.ObservationWrapper):
|
class AddJointVelocityToObservation(gym.ObservationWrapper):
|
||||||
def __init__(self, env, joint_velocity_limits=100.0, fps=30):
|
def __init__(self, env, joint_velocity_limits=100.0, fps=30, num_dof=6):
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
|
|
||||||
# Extend observation space to include joint velocities
|
# Extend observation space to include joint velocities
|
||||||
@@ -312,12 +312,12 @@ class AddJointVelocityToObservation(gym.ObservationWrapper):
|
|||||||
old_high = self.observation_space["observation.state"].high
|
old_high = self.observation_space["observation.state"].high
|
||||||
old_shape = self.observation_space["observation.state"].shape
|
old_shape = self.observation_space["observation.state"].shape
|
||||||
|
|
||||||
self.last_joint_positions = np.zeros(old_shape)
|
self.last_joint_positions = np.zeros(num_dof)
|
||||||
|
|
||||||
new_low = np.concatenate([old_low, np.ones_like(old_low) * -joint_velocity_limits])
|
new_low = np.concatenate([old_low, np.ones(num_dof) * -joint_velocity_limits])
|
||||||
new_high = np.concatenate([old_high, np.ones_like(old_high) * joint_velocity_limits])
|
new_high = np.concatenate([old_high, np.ones(num_dof) * joint_velocity_limits])
|
||||||
|
|
||||||
new_shape = (old_shape[0] * 2,)
|
new_shape = (old_shape[0] + num_dof,)
|
||||||
|
|
||||||
self.observation_space["observation.state"] = gym.spaces.Box(
|
self.observation_space["observation.state"] = gym.spaces.Box(
|
||||||
low=new_low,
|
low=new_low,
|
||||||
@@ -337,6 +337,37 @@ class AddJointVelocityToObservation(gym.ObservationWrapper):
|
|||||||
return observation
|
return observation
|
||||||
|
|
||||||
|
|
||||||
|
class AddCurrentToObservation(gym.ObservationWrapper):
|
||||||
|
def __init__(self, env, max_current=500, num_dof=6):
|
||||||
|
super().__init__(env)
|
||||||
|
|
||||||
|
# Extend observation space to include motor current readings
|
||||||
|
old_low = self.observation_space["observation.state"].low
|
||||||
|
old_high = self.observation_space["observation.state"].high
|
||||||
|
old_shape = self.observation_space["observation.state"].shape
|
||||||
|
|
||||||
|
|
||||||
|
new_low = np.concatenate([old_low, np.zeros(num_dof)])
|
||||||
|
new_high = np.concatenate([old_high, np.ones(num_dof) * max_current])
|
||||||
|
|
||||||
|
new_shape = (old_shape[0] + num_dof,)
|
||||||
|
|
||||||
|
self.observation_space["observation.state"] = gym.spaces.Box(
|
||||||
|
low=new_low,
|
||||||
|
high=new_high,
|
||||||
|
shape=new_shape,
|
||||||
|
dtype=np.float32,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def observation(self, observation):
|
||||||
|
present_current = self.unwrapped.robot.follower_arms["main"].read("Present_Current").astype(np.float32)
|
||||||
|
observation["observation.state"] = torch.cat(
|
||||||
|
[observation["observation.state"], torch.from_numpy(present_current)], dim=-1
|
||||||
|
)
|
||||||
|
return observation
|
||||||
|
|
||||||
|
|
||||||
class ActionRepeatWrapper(gym.Wrapper):
|
class ActionRepeatWrapper(gym.Wrapper):
|
||||||
def __init__(self, env, nb_repeat: int = 1):
|
def __init__(self, env, nb_repeat: int = 1):
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
@@ -553,6 +584,9 @@ class ImageCropResizeWrapper(gym.Wrapper):
|
|||||||
# TODO(michel-aractingi): Bug in resize, it returns values outside [0, 1]
|
# TODO(michel-aractingi): Bug in resize, it returns values outside [0, 1]
|
||||||
obs[k] = obs[k].clamp(0.0, 1.0)
|
obs[k] = obs[k].clamp(0.0, 1.0)
|
||||||
|
|
||||||
|
# import cv2
|
||||||
|
# cv2.imwrite(f"tmp_img/{k}.jpg", obs[k].squeeze(0).permute(1, 2, 0).cpu().numpy() * 255)
|
||||||
|
|
||||||
# Check for NaNs after processing
|
# Check for NaNs after processing
|
||||||
if torch.isnan(obs[k]).any():
|
if torch.isnan(obs[k]).any():
|
||||||
logging.error(f"NaN values detected in observation {k} after crop and resize")
|
logging.error(f"NaN values detected in observation {k} after crop and resize")
|
||||||
@@ -812,16 +846,26 @@ class GripperPenaltyWrapper(gym.RewardWrapper):
|
|||||||
|
|
||||||
|
|
||||||
class GripperActionWrapper(gym.ActionWrapper):
|
class GripperActionWrapper(gym.ActionWrapper):
|
||||||
def __init__(self, env, quantization_threshold: float = 0.2):
|
def __init__(self, env, quantization_threshold: float = 0.2, gripper_sleep: float = 0.0):
|
||||||
super().__init__(env)
|
super().__init__(env)
|
||||||
self.quantization_threshold = quantization_threshold
|
self.quantization_threshold = quantization_threshold
|
||||||
|
self.gripper_sleep = gripper_sleep
|
||||||
|
self.last_gripper_action_time = 0.0
|
||||||
|
self.last_gripper_action = None
|
||||||
|
|
||||||
def action(self, action):
|
def action(self, action):
|
||||||
is_intervention = False
|
is_intervention = False
|
||||||
if isinstance(action, tuple):
|
if isinstance(action, tuple):
|
||||||
action, is_intervention = action
|
action, is_intervention = action
|
||||||
gripper_command = action[-1]
|
|
||||||
|
|
||||||
|
if self.gripper_sleep > 0.0:
|
||||||
|
if self.last_gripper_action is not None and time.perf_counter() - self.last_gripper_action_time < self.gripper_sleep:
|
||||||
|
action[-1] = self.last_gripper_action
|
||||||
|
else:
|
||||||
|
self.last_gripper_action_time = time.perf_counter()
|
||||||
|
self.last_gripper_action = action[-1]
|
||||||
|
|
||||||
|
gripper_command = action[-1]
|
||||||
# Gripper actions are between 0, 2
|
# Gripper actions are between 0, 2
|
||||||
# we want to quantize them to -1, 0 or 1
|
# we want to quantize them to -1, 0 or 1
|
||||||
gripper_command = gripper_command - 1.0
|
gripper_command = gripper_command - 1.0
|
||||||
@@ -837,6 +881,12 @@ class GripperActionWrapper(gym.ActionWrapper):
|
|||||||
action[-1] = gripper_action.item()
|
action[-1] = gripper_action.item()
|
||||||
return action, is_intervention
|
return action, is_intervention
|
||||||
|
|
||||||
|
def reset(self, **kwargs):
|
||||||
|
obs, info = super().reset(**kwargs)
|
||||||
|
self.last_gripper_action_time = 0.0
|
||||||
|
self.last_gripper_action = None
|
||||||
|
return obs, info
|
||||||
|
|
||||||
|
|
||||||
class EEActionWrapper(gym.ActionWrapper):
|
class EEActionWrapper(gym.ActionWrapper):
|
||||||
def __init__(self, env, ee_action_space_params=None, use_gripper=False):
|
def __init__(self, env, ee_action_space_params=None, use_gripper=False):
|
||||||
@@ -1171,6 +1221,8 @@ def make_robot_env(cfg) -> gym.vector.VectorEnv:
|
|||||||
# Add observation and image processing
|
# Add observation and image processing
|
||||||
if cfg.wrapper.add_joint_velocity_to_observation:
|
if cfg.wrapper.add_joint_velocity_to_observation:
|
||||||
env = AddJointVelocityToObservation(env=env, fps=cfg.fps)
|
env = AddJointVelocityToObservation(env=env, fps=cfg.fps)
|
||||||
|
if cfg.wrapper.add_current_to_observation:
|
||||||
|
env = AddCurrentToObservation(env=env)
|
||||||
if cfg.wrapper.add_ee_pose_to_observation:
|
if cfg.wrapper.add_ee_pose_to_observation:
|
||||||
env = EEObservationWrapper(env=env, ee_pose_limits=cfg.wrapper.ee_action_space_params.bounds)
|
env = EEObservationWrapper(env=env, ee_pose_limits=cfg.wrapper.ee_action_space_params.bounds)
|
||||||
|
|
||||||
@@ -1454,7 +1506,7 @@ def main(cfg: EnvConfig):
|
|||||||
|
|
||||||
dt_s = time.perf_counter() - start_loop_s
|
dt_s = time.perf_counter() - start_loop_s
|
||||||
busy_wait(1 / cfg.fps - dt_s)
|
busy_wait(1 / cfg.fps - dt_s)
|
||||||
|
|
||||||
logging.info(f"Success after 20 steps {sucesses}")
|
logging.info(f"Success after 20 steps {sucesses}")
|
||||||
logging.info(f"success rate {sum(sucesses) / len(sucesses)}")
|
logging.info(f"success rate {sum(sucesses) / len(sucesses)}")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user