168 lines
7.8 KiB
Python
168 lines
7.8 KiB
Python
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
|
|
# ruff: noqa
|
|
import collections
|
|
import time
|
|
|
|
import dm_env
|
|
from interbotix_xs_modules.arm import InterbotixManipulatorXS
|
|
from interbotix_xs_msgs.msg import JointSingleCommand
|
|
import numpy as np
|
|
|
|
from examples.aloha_real import constants
|
|
from examples.aloha_real import robot_utils
|
|
|
|
|
|
class RealEnv:
    """Bi-manual manipulation environment driving two real puppet arms.

    Action layout (14 values, one 7-value half per arm, left half first):
        [left_arm_qpos (6),           # absolute joint position
         left_gripper_position (1),   # normalized gripper position (0: close, 1: open)
         right_arm_qpos (6),          # absolute joint position
         right_gripper_position (1)]  # normalized gripper position (0: close, 1: open)

    Observation (an ordered mapping, keys in this order):
        "qpos":   Concat[left_arm_qpos (6),          # absolute joint position
                         left_gripper_position (1),  # normalized gripper position (0: close, 1: open)
                         right_arm_qpos (6),         # absolute joint position
                         right_gripper_position (1)] # normalized gripper position (0: close, 1: open)
        "qvel":   Concat[left_arm_qvel (6),          # absolute joint velocity (rad)
                         left_gripper_velocity (1),  # normalized gripper velocity (pos: opening, neg: closing)
                         right_arm_qvel (6),         # absolute joint velocity (rad)
                         right_gripper_velocity (1)] # normalized gripper velocity (pos: opening, neg: closing)
        "effort": Concat[left_effort[:7], right_effort[:7]]  # first seven recorder effort values per arm
        "images": {"cam_high": (480x640x3),          # h, w, c, dtype='uint8'
                   "cam_low": (480x640x3),           # h, w, c, dtype='uint8'
                   "cam_left_wrist": (480x640x3),    # h, w, c, dtype='uint8'
                   "cam_right_wrist": (480x640x3)}   # h, w, c, dtype='uint8'

    NOTE(review): the camera names and image shapes are carried over from the
    original documentation; they are produced by ``robot_utils.ImageRecorder``
    and are not verifiable from this class alone.
    """

    def __init__(self, init_node, *, setup_robots: bool = True):
        """Create both puppet arm handles, the state recorders and the gripper command.

        Args:
            init_node: Forwarded as ``init_node`` to the left puppet arm only;
                the right arm and every recorder are created with ``init_node=False``.
            setup_robots: When true, ``setup_robots()`` is run on both arms
                before the recorders are created.
        """
        shared_kwargs = dict(robot_model="vx300s", group_name="arm", gripper_name="gripper")
        # The left arm is constructed first and is the only object handed the caller's init_node.
        self.puppet_bot_left = InterbotixManipulatorXS(robot_name="puppet_left", init_node=init_node, **shared_kwargs)
        self.puppet_bot_right = InterbotixManipulatorXS(robot_name="puppet_right", init_node=False, **shared_kwargs)
        if setup_robots:
            self.setup_robots()

        self.recorder_left = robot_utils.Recorder("left", init_node=False)
        self.recorder_right = robot_utils.Recorder("right", init_node=False)
        self.image_recorder = robot_utils.ImageRecorder(init_node=False)
        # One message object is reused for both grippers; see set_gripper_pose.
        self.gripper_command = JointSingleCommand(name="gripper")

    def setup_robots(self):
        """Run ``robot_utils.setup_puppet_bot`` on the left arm, then the right arm."""
        for puppet in (self.puppet_bot_left, self.puppet_bot_right):
            robot_utils.setup_puppet_bot(puppet)

    def get_qpos(self):
        """Return arm joint positions and normalized gripper positions for both arms.

        Layout: left arm (raw[:6]), left gripper, right arm (raw[:6]), right gripper.
        """
        raw_readings = (self.recorder_left.qpos, self.recorder_right.qpos)
        to_normalized = constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN
        segments = []
        for raw in raw_readings:
            segments.append(raw[:6])
            # Index 7 is treated as a gripper *position*, not a joint angle (per the original note).
            segments.append([to_normalized(raw[7])])
        return np.concatenate(segments)

    def get_qvel(self):
        """Return arm joint velocities and normalized gripper velocities for both arms.

        Layout mirrors ``get_qpos``: left arm (raw[:6]), left gripper, right arm, right gripper.
        """
        raw_readings = (self.recorder_left.qvel, self.recorder_right.qvel)
        to_normalized = constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN
        segments = []
        for raw in raw_readings:
            segments.append(raw[:6])
            segments.append([to_normalized(raw[7])])
        return np.concatenate(segments)

    def get_effort(self):
        """Return the first seven effort readings of the left arm followed by the right arm's."""
        per_arm = [recorder.effort[:7] for recorder in (self.recorder_left, self.recorder_right)]
        return np.concatenate(per_arm)

    def get_images(self):
        """Return whatever the image recorder currently reports via ``get_images()``."""
        recorder = self.image_recorder
        return recorder.get_images()

    def set_gripper_pose(self, left_gripper_desired_pos_normalized, right_gripper_desired_pos_normalized):
        """Publish a gripper command for the left arm, then for the right arm.

        Each normalized target is converted with
        ``constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN`` and written into the
        shared ``self.gripper_command`` message before it is published.
        """
        targets = (
            (self.puppet_bot_left, left_gripper_desired_pos_normalized),
            (self.puppet_bot_right, right_gripper_desired_pos_normalized),
        )
        for puppet, normalized_target in targets:
            # The same message instance is overwritten for each arm right before publishing.
            self.gripper_command.cmd = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(normalized_target)
            puppet.gripper.core.pub_single.publish(self.gripper_command)

    def _reset_joints(self):
        """Move both arms to the fixed reset pose over a one-second move."""
        home_pose = [0, -1.5, 1.5, 0, 0, 0]
        puppets = [self.puppet_bot_left, self.puppet_bot_right]
        robot_utils.move_arms(puppets, [home_pose, home_pose], move_time=1)

    def _reset_gripper(self):
        """Reset both grippers by opening them (0.5 s move) and then closing them (1 s move).

        NOTE(review): the original note says this sets position mode and then
        changes back to PWM mode; any mode switching would happen inside
        ``robot_utils.move_grippers`` and is not visible here.
        """
        sequence = (
            (constants.PUPPET_GRIPPER_JOINT_OPEN, 0.5),
            (constants.PUPPET_GRIPPER_JOINT_CLOSE, 1),
        )
        for joint_target, duration in sequence:
            robot_utils.move_grippers(
                [self.puppet_bot_left, self.puppet_bot_right], [joint_target] * 2, move_time=duration
            )

    def get_observation(self):
        """Collect qpos, qvel, effort and images (in that order) into an ordered mapping."""
        return collections.OrderedDict(
            [
                ("qpos", self.get_qpos()),
                ("qvel", self.get_qvel()),
                ("effort", self.get_effort()),
                ("images", self.get_images()),
            ]
        )

    def get_reward(self):
        """Return the reward, which is always 0 for this environment."""
        return 0

    def reset(self, *, fake=False):
        """Reset the robots (unless ``fake``) and return the first time step.

        Args:
            fake: When true, no hardware command is issued and only the current
                observation is packaged.

        Returns:
            A ``dm_env.TimeStep`` with ``StepType.FIRST`` and ``discount=None``.
        """
        if not fake:
            # Reboot puppet robot gripper motors, left arm first.
            for puppet in (self.puppet_bot_left, self.puppet_bot_right):
                puppet.dxl.robot_reboot_motors("single", "gripper", True)
            self._reset_joints()
            self._reset_gripper()
        return dm_env.TimeStep(
            step_type=dm_env.StepType.FIRST,
            reward=self.get_reward(),
            discount=None,
            observation=self.get_observation(),
        )

    def step(self, action):
        """Send one action to both arms, wait ``constants.DT`` and return the new time step.

        The action is split into a left half and a right half. The first six
        values of each half are sent as non-blocking joint position targets and
        the last value of each half is used as the normalized gripper target.

        Returns:
            A ``dm_env.TimeStep`` with ``StepType.MID`` and ``discount=None``.
        """
        half = len(action) // 2
        left_half, right_half = action[:half], action[half:]
        self.puppet_bot_left.arm.set_joint_positions(left_half[:6], blocking=False)
        self.puppet_bot_right.arm.set_joint_positions(right_half[:6], blocking=False)
        self.set_gripper_pose(left_half[-1], right_half[-1])
        time.sleep(constants.DT)
        return dm_env.TimeStep(
            step_type=dm_env.StepType.MID,
            reward=self.get_reward(),
            discount=None,
            observation=self.get_observation(),
        )
|
|
|
|
|
|
def get_action(master_bot_left, master_bot_right):
    """Build a 14-value action from the current joint states of the two master arms.

    Layout: left arm joints (6), left gripper (1), right arm joints (6),
    right gripper (1). Arm values are the first six reported joint positions;
    each gripper value is joint position 6 passed through
    ``constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN``.
    """
    arm_dof = 6
    left_gripper_idx = arm_dof  # 6
    right_start = left_gripper_idx + 1  # 7
    right_gripper_idx = right_start + arm_dof  # 13

    command = np.zeros(14)  # 6 joint + 1 gripper, for two arms

    # Arm joints: left first, then right.
    command[0:left_gripper_idx] = master_bot_left.dxl.joint_states.position[:arm_dof]
    command[right_start:right_gripper_idx] = master_bot_right.dxl.joint_states.position[:arm_dof]

    # Grippers: normalized master gripper joint, left first, then right.
    normalize_gripper = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN
    command[left_gripper_idx] = normalize_gripper(master_bot_left.dxl.joint_states.position[6])
    command[right_gripper_idx] = normalize_gripper(master_bot_right.dxl.joint_states.position[6])

    return command
|
|
|
|
|
|
def make_real_env(init_node, *, setup_robots: bool = True) -> RealEnv:
    """Construct a ``RealEnv``, forwarding ``init_node`` and ``setup_robots`` unchanged."""
    env = RealEnv(init_node, setup_robots=setup_robots)
    return env
|