Initial commit
This commit is contained in:
275
examples/aloha_real/robot_utils.py
Normal file
275
examples/aloha_real/robot_utils.py
Normal file
@@ -0,0 +1,275 @@
|
||||
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
|
||||
# ruff: noqa
|
||||
from collections import deque
|
||||
import datetime
|
||||
import json
|
||||
import time
|
||||
|
||||
from aloha.msg import RGBGrayscaleImage
|
||||
from cv_bridge import CvBridge
|
||||
from interbotix_xs_msgs.msg import JointGroupCommand
|
||||
from interbotix_xs_msgs.msg import JointSingleCommand
|
||||
import numpy as np
|
||||
import rospy
|
||||
from sensor_msgs.msg import JointState
|
||||
|
||||
from examples.aloha_real import constants
|
||||
|
||||
|
||||
class ImageRecorder:
|
||||
def __init__(self, init_node=True, is_debug=False):
|
||||
self.is_debug = is_debug
|
||||
self.bridge = CvBridge()
|
||||
self.camera_names = ["cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist"]
|
||||
|
||||
if init_node:
|
||||
rospy.init_node("image_recorder", anonymous=True)
|
||||
for cam_name in self.camera_names:
|
||||
setattr(self, f"{cam_name}_rgb_image", None)
|
||||
setattr(self, f"{cam_name}_depth_image", None)
|
||||
setattr(self, f"{cam_name}_timestamp", 0.0)
|
||||
if cam_name == "cam_high":
|
||||
callback_func = self.image_cb_cam_high
|
||||
elif cam_name == "cam_low":
|
||||
callback_func = self.image_cb_cam_low
|
||||
elif cam_name == "cam_left_wrist":
|
||||
callback_func = self.image_cb_cam_left_wrist
|
||||
elif cam_name == "cam_right_wrist":
|
||||
callback_func = self.image_cb_cam_right_wrist
|
||||
else:
|
||||
raise NotImplementedError
|
||||
rospy.Subscriber(f"/{cam_name}", RGBGrayscaleImage, callback_func)
|
||||
if self.is_debug:
|
||||
setattr(self, f"{cam_name}_timestamps", deque(maxlen=50))
|
||||
|
||||
self.cam_last_timestamps = {cam_name: 0.0 for cam_name in self.camera_names}
|
||||
time.sleep(0.5)
|
||||
|
||||
def image_cb(self, cam_name, data):
|
||||
setattr(
|
||||
self,
|
||||
f"{cam_name}_rgb_image",
|
||||
self.bridge.imgmsg_to_cv2(data.images[0], desired_encoding="bgr8"),
|
||||
)
|
||||
# setattr(
|
||||
# self,
|
||||
# f"{cam_name}_depth_image",
|
||||
# self.bridge.imgmsg_to_cv2(data.images[1], desired_encoding="mono16"),
|
||||
# )
|
||||
setattr(
|
||||
self,
|
||||
f"{cam_name}_timestamp",
|
||||
data.header.stamp.secs + data.header.stamp.nsecs * 1e-9,
|
||||
)
|
||||
# setattr(self, f'{cam_name}_secs', data.images[0].header.stamp.secs)
|
||||
# setattr(self, f'{cam_name}_nsecs', data.images[0].header.stamp.nsecs)
|
||||
# cv2.imwrite('/home/lucyshi/Desktop/sample.jpg', cv_image)
|
||||
if self.is_debug:
|
||||
getattr(self, f"{cam_name}_timestamps").append(
|
||||
data.images[0].header.stamp.secs + data.images[0].header.stamp.nsecs * 1e-9
|
||||
)
|
||||
|
||||
def image_cb_cam_high(self, data):
|
||||
cam_name = "cam_high"
|
||||
return self.image_cb(cam_name, data)
|
||||
|
||||
def image_cb_cam_low(self, data):
|
||||
cam_name = "cam_low"
|
||||
return self.image_cb(cam_name, data)
|
||||
|
||||
def image_cb_cam_left_wrist(self, data):
|
||||
cam_name = "cam_left_wrist"
|
||||
return self.image_cb(cam_name, data)
|
||||
|
||||
def image_cb_cam_right_wrist(self, data):
|
||||
cam_name = "cam_right_wrist"
|
||||
return self.image_cb(cam_name, data)
|
||||
|
||||
def get_images(self):
|
||||
image_dict = {}
|
||||
for cam_name in self.camera_names:
|
||||
while getattr(self, f"{cam_name}_timestamp") <= self.cam_last_timestamps[cam_name]:
|
||||
time.sleep(0.00001)
|
||||
rgb_image = getattr(self, f"{cam_name}_rgb_image")
|
||||
depth_image = getattr(self, f"{cam_name}_depth_image")
|
||||
self.cam_last_timestamps[cam_name] = getattr(self, f"{cam_name}_timestamp")
|
||||
image_dict[cam_name] = rgb_image
|
||||
image_dict[f"{cam_name}_depth"] = depth_image
|
||||
return image_dict
|
||||
|
||||
def print_diagnostics(self):
|
||||
def dt_helper(l):
|
||||
l = np.array(l)
|
||||
diff = l[1:] - l[:-1]
|
||||
return np.mean(diff)
|
||||
|
||||
for cam_name in self.camera_names:
|
||||
image_freq = 1 / dt_helper(getattr(self, f"{cam_name}_timestamps"))
|
||||
print(f"{cam_name} {image_freq=:.2f}")
|
||||
print()
|
||||
|
||||
|
||||
class Recorder:
|
||||
def __init__(self, side, init_node=True, is_debug=False):
|
||||
self.secs = None
|
||||
self.nsecs = None
|
||||
self.qpos = None
|
||||
self.effort = None
|
||||
self.arm_command = None
|
||||
self.gripper_command = None
|
||||
self.is_debug = is_debug
|
||||
|
||||
if init_node:
|
||||
rospy.init_node("recorder", anonymous=True)
|
||||
rospy.Subscriber(f"/puppet_{side}/joint_states", JointState, self.puppet_state_cb)
|
||||
rospy.Subscriber(
|
||||
f"/puppet_{side}/commands/joint_group",
|
||||
JointGroupCommand,
|
||||
self.puppet_arm_commands_cb,
|
||||
)
|
||||
rospy.Subscriber(
|
||||
f"/puppet_{side}/commands/joint_single",
|
||||
JointSingleCommand,
|
||||
self.puppet_gripper_commands_cb,
|
||||
)
|
||||
if self.is_debug:
|
||||
self.joint_timestamps = deque(maxlen=50)
|
||||
self.arm_command_timestamps = deque(maxlen=50)
|
||||
self.gripper_command_timestamps = deque(maxlen=50)
|
||||
time.sleep(0.1)
|
||||
|
||||
def puppet_state_cb(self, data):
|
||||
self.qpos = data.position
|
||||
self.qvel = data.velocity
|
||||
self.effort = data.effort
|
||||
self.data = data
|
||||
if self.is_debug:
|
||||
self.joint_timestamps.append(time.time())
|
||||
|
||||
def puppet_arm_commands_cb(self, data):
|
||||
self.arm_command = data.cmd
|
||||
if self.is_debug:
|
||||
self.arm_command_timestamps.append(time.time())
|
||||
|
||||
def puppet_gripper_commands_cb(self, data):
|
||||
self.gripper_command = data.cmd
|
||||
if self.is_debug:
|
||||
self.gripper_command_timestamps.append(time.time())
|
||||
|
||||
def print_diagnostics(self):
|
||||
def dt_helper(l):
|
||||
l = np.array(l)
|
||||
diff = l[1:] - l[:-1]
|
||||
return np.mean(diff)
|
||||
|
||||
joint_freq = 1 / dt_helper(self.joint_timestamps)
|
||||
arm_command_freq = 1 / dt_helper(self.arm_command_timestamps)
|
||||
gripper_command_freq = 1 / dt_helper(self.gripper_command_timestamps)
|
||||
|
||||
print(f"{joint_freq=:.2f}\n{arm_command_freq=:.2f}\n{gripper_command_freq=:.2f}\n")
|
||||
|
||||
|
||||
def get_arm_joint_positions(bot):
|
||||
return bot.arm.core.joint_states.position[:6]
|
||||
|
||||
|
||||
def get_arm_gripper_positions(bot):
|
||||
return bot.gripper.core.joint_states.position[6]
|
||||
|
||||
|
||||
def move_arms(bot_list, target_pose_list, move_time=1):
|
||||
num_steps = int(move_time / constants.DT)
|
||||
curr_pose_list = [get_arm_joint_positions(bot) for bot in bot_list]
|
||||
traj_list = [
|
||||
np.linspace(curr_pose, target_pose, num_steps)
|
||||
for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)
|
||||
]
|
||||
for t in range(num_steps):
|
||||
for bot_id, bot in enumerate(bot_list):
|
||||
bot.arm.set_joint_positions(traj_list[bot_id][t], blocking=False)
|
||||
time.sleep(constants.DT)
|
||||
|
||||
|
||||
def move_grippers(bot_list, target_pose_list, move_time):
|
||||
print(f"Moving grippers to {target_pose_list=}")
|
||||
gripper_command = JointSingleCommand(name="gripper")
|
||||
num_steps = int(move_time / constants.DT)
|
||||
curr_pose_list = [get_arm_gripper_positions(bot) for bot in bot_list]
|
||||
traj_list = [
|
||||
np.linspace(curr_pose, target_pose, num_steps)
|
||||
for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)
|
||||
]
|
||||
|
||||
with open(f"/data/gripper_traj_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl", "a") as f:
|
||||
for t in range(num_steps):
|
||||
d = {}
|
||||
for bot_id, bot in enumerate(bot_list):
|
||||
gripper_command.cmd = traj_list[bot_id][t]
|
||||
bot.gripper.core.pub_single.publish(gripper_command)
|
||||
d[bot_id] = {"obs": get_arm_gripper_positions(bot), "act": traj_list[bot_id][t]}
|
||||
f.write(json.dumps(d) + "\n")
|
||||
time.sleep(constants.DT)
|
||||
|
||||
|
||||
def setup_puppet_bot(bot):
|
||||
bot.dxl.robot_reboot_motors("single", "gripper", True)
|
||||
bot.dxl.robot_set_operating_modes("group", "arm", "position")
|
||||
bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
|
||||
torque_on(bot)
|
||||
|
||||
|
||||
def setup_master_bot(bot):
|
||||
bot.dxl.robot_set_operating_modes("group", "arm", "pwm")
|
||||
bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
|
||||
torque_off(bot)
|
||||
|
||||
|
||||
def set_standard_pid_gains(bot):
|
||||
bot.dxl.robot_set_motor_registers("group", "arm", "Position_P_Gain", 800)
|
||||
bot.dxl.robot_set_motor_registers("group", "arm", "Position_I_Gain", 0)
|
||||
|
||||
|
||||
def set_low_pid_gains(bot):
|
||||
bot.dxl.robot_set_motor_registers("group", "arm", "Position_P_Gain", 100)
|
||||
bot.dxl.robot_set_motor_registers("group", "arm", "Position_I_Gain", 0)
|
||||
|
||||
|
||||
def torque_off(bot):
|
||||
bot.dxl.robot_torque_enable("group", "arm", False)
|
||||
bot.dxl.robot_torque_enable("single", "gripper", False)
|
||||
|
||||
|
||||
def torque_on(bot):
|
||||
bot.dxl.robot_torque_enable("group", "arm", True)
|
||||
bot.dxl.robot_torque_enable("single", "gripper", True)
|
||||
|
||||
|
||||
# for DAgger
|
||||
def sync_puppet_to_master(master_bot_left, master_bot_right, puppet_bot_left, puppet_bot_right):
|
||||
print("\nSyncing!")
|
||||
|
||||
# activate master arms
|
||||
torque_on(master_bot_left)
|
||||
torque_on(master_bot_right)
|
||||
|
||||
# get puppet arm positions
|
||||
puppet_left_qpos = get_arm_joint_positions(puppet_bot_left)
|
||||
puppet_right_qpos = get_arm_joint_positions(puppet_bot_right)
|
||||
|
||||
# get puppet gripper positions
|
||||
puppet_left_gripper = get_arm_gripper_positions(puppet_bot_left)
|
||||
puppet_right_gripper = get_arm_gripper_positions(puppet_bot_right)
|
||||
|
||||
# move master arms to puppet positions
|
||||
move_arms(
|
||||
[master_bot_left, master_bot_right],
|
||||
[puppet_left_qpos, puppet_right_qpos],
|
||||
move_time=1,
|
||||
)
|
||||
|
||||
# move master grippers to puppet positions
|
||||
move_grippers(
|
||||
[master_bot_left, master_bot_right],
|
||||
[puppet_left_gripper, puppet_right_gripper],
|
||||
move_time=1,
|
||||
)
|
||||
Reference in New Issue
Block a user