diff --git a/collect_data/README.MD b/collect_data/README.MD deleted file mode 100644 index 9e4d14a..0000000 --- a/collect_data/README.MD +++ /dev/null @@ -1,3 +0,0 @@ -python collect_data.py --robot.type=aloha --control.type=record --control.fps=30 --control.single_task="Grasp a lego block and put it in the bin." --control.repo_id=tangger/test --control.num_episodes=1 --control.root=./data - -python lerobot/scripts/train.py --dataset.repo_id=maic/move_tube_on_scale --policy.type=act --output_dir=outputs/train/act_move_tube_on_scale --job_name=act_move_tube_on_scale --policy.device=cuda --wandb.enable=true --dataset.root=/home/ubuntu/LYT/aloha_lerobot/data1 \ No newline at end of file diff --git a/collect_data/__pycache__/agilex_robot.cpython-310.pyc b/collect_data/__pycache__/agilex_robot.cpython-310.pyc deleted file mode 100644 index b8b4a33..0000000 Binary files a/collect_data/__pycache__/agilex_robot.cpython-310.pyc and /dev/null differ diff --git a/collect_data/__pycache__/robot_components.cpython-310.pyc b/collect_data/__pycache__/robot_components.cpython-310.pyc deleted file mode 100644 index 2f08b67..0000000 Binary files a/collect_data/__pycache__/robot_components.cpython-310.pyc and /dev/null differ diff --git a/collect_data/__pycache__/ros_robot.cpython-310.pyc b/collect_data/__pycache__/ros_robot.cpython-310.pyc deleted file mode 100644 index b95e9b2..0000000 Binary files a/collect_data/__pycache__/ros_robot.cpython-310.pyc and /dev/null differ diff --git a/collect_data/__pycache__/rosoperator.cpython-310.pyc b/collect_data/__pycache__/rosoperator.cpython-310.pyc deleted file mode 100644 index 7c51e60..0000000 Binary files a/collect_data/__pycache__/rosoperator.cpython-310.pyc and /dev/null differ diff --git a/collect_data/__pycache__/rosrobot.cpython-310.pyc b/collect_data/__pycache__/rosrobot.cpython-310.pyc deleted file mode 100644 index 67ba2c5..0000000 Binary files a/collect_data/__pycache__/rosrobot.cpython-310.pyc and /dev/null differ diff --git a/collect_data/__pycache__/rosrobot_factory.cpython-310.pyc b/collect_data/__pycache__/rosrobot_factory.cpython-310.pyc deleted file mode 100644 index fc77f3f..0000000 Binary files a/collect_data/__pycache__/rosrobot_factory.cpython-310.pyc and /dev/null differ diff --git a/collect_data/agilex.yaml b/collect_data/agilex.yaml deleted file mode 100644 index 703b7e2..0000000 --- a/collect_data/agilex.yaml +++ /dev/null @@ -1,146 +0,0 @@ -robot_type: aloha_agilex -ros_node_name: record_episodes -cameras: - cam_front: - img_topic_name: /camera_f/color/image_raw - depth_topic_name: /camera_f/depth/image_raw - width: 480 - height: 640 - rgb_shape: [480, 640, 3] - cam_left: - img_topic_name: /camera_l/color/image_raw - depth_topic_name: /camera_l/depth/image_raw - rgb_shape: [480, 640, 3] - width: 480 - height: 640 - cam_right: - img_topic_name: /camera_r/color/image_raw - depth_topic_name: /camera_r/depth/image_raw - rgb_shape: [480, 640, 3] - width: 480 - height: 640 - cam_high: - img_topic_name: /camera/color/image_raw - depth_topic_name: /camera/depth/image_rect_raw - rgb_shape: [480, 640, 3] - width: 480 - height: 640 - -arm: - master_left: - topic_name: /master/joint_left - motors: [ - "left_joint0", - "left_joint1", - "left_joint2", - "left_joint3", - "left_joint4", - "left_joint5", - "left_none" - ] - master_right: - topic_name: /master/joint_right - motors: [ - "right_joint0", - "right_joint1", - "right_joint2", - "right_joint3", - "right_joint4", - "right_joint5", - "right_none" - ] - puppet_left: - topic_name: 
/puppet/joint_left - motors: [ - "left_joint0", - "left_joint1", - "left_joint2", - "left_joint3", - "left_joint4", - "left_joint5", - "left_none" - ] - puppet_right: - topic_name: /puppet/joint_right - motors: [ - "right_joint0", - "right_joint1", - "right_joint2", - "right_joint3", - "right_joint4", - "right_joint5", - "right_none" - ] - -# follow the joint name in ros -state: - motors: [ - "left_joint0", - "left_joint1", - "left_joint2", - "left_joint3", - "left_joint4", - "left_joint5", - "left_none", - "right_joint0", - "right_joint1", - "right_joint2", - "right_joint3", - "right_joint4", - "right_joint5", - "right_none" - ] - -velocity: - motors: [ - "left_joint0", - "left_joint1", - "left_joint2", - "left_joint3", - "left_joint4", - "left_joint5", - "left_none", - "right_joint0", - "right_joint1", - "right_joint2", - "right_joint3", - "right_joint4", - "right_joint5", - "right_none" - ] - -effort: - motors: [ - "left_joint0", - "left_joint1", - "left_joint2", - "left_joint3", - "left_joint4", - "left_joint5", - "left_none", - "right_joint0", - "right_joint1", - "right_joint2", - "right_joint3", - "right_joint4", - "right_joint5", - "right_none" - ] - -action: - motors: [ - "left_joint0", - "left_joint1", - "left_joint2", - "left_joint3", - "left_joint4", - "left_joint5", - "left_none", - "right_joint0", - "right_joint1", - "right_joint2", - "right_joint3", - "right_joint4", - "right_joint5", - "right_none" - ] diff --git a/collect_data/aloha_mobile.py b/collect_data/aloha_mobile.py new file mode 100644 index 0000000..538ecb5 --- /dev/null +++ b/collect_data/aloha_mobile.py @@ -0,0 +1,305 @@ +import cv2 +import numpy as np +import dm_env + +import collections +from collections import deque + +import rospy +from sensor_msgs.msg import JointState +from sensor_msgs.msg import Image +from nav_msgs.msg import Odometry +from cv_bridge import CvBridge +from utils import display_camera_grid, save_data + + +class AlohaRobotRos: + def __init__(self, args): + self.robot_base_deque = None + self.puppet_arm_right_deque = None + self.puppet_arm_left_deque = None + self.master_arm_right_deque = None + self.master_arm_left_deque = None + self.img_front_deque = None + self.img_right_deque = None + self.img_left_deque = None + self.img_front_depth_deque = None + self.img_right_depth_deque = None + self.img_left_depth_deque = None + self.bridge = None + self.args = args + self.init() + self.init_ros() + + def init(self): + self.bridge = CvBridge() + self.img_left_deque = deque() + self.img_right_deque = deque() + self.img_front_deque = deque() + self.img_left_depth_deque = deque() + self.img_right_depth_deque = deque() + self.img_front_depth_deque = deque() + self.master_arm_left_deque = deque() + self.master_arm_right_deque = deque() + self.puppet_arm_left_deque = deque() + self.puppet_arm_right_deque = deque() + self.robot_base_deque = deque() + + def get_frame(self): + print(len(self.img_left_deque), len(self.img_right_deque), len(self.img_front_deque), + len(self.img_left_depth_deque), len(self.img_right_depth_deque), len(self.img_front_depth_deque)) + if len(self.img_left_deque) == 0 or len(self.img_right_deque) == 0 or len(self.img_front_deque) == 0 or \ + (self.args.use_depth_image and (len(self.img_left_depth_deque) == 0 or len(self.img_right_depth_deque) == 0 or len(self.img_front_depth_deque) == 0)): + return False + if self.args.use_depth_image: + frame_time = min([self.img_left_deque[-1].header.stamp.to_sec(), self.img_right_deque[-1].header.stamp.to_sec(), 
self.img_front_deque[-1].header.stamp.to_sec(), + self.img_left_depth_deque[-1].header.stamp.to_sec(), self.img_right_depth_deque[-1].header.stamp.to_sec(), self.img_front_depth_deque[-1].header.stamp.to_sec()]) + else: + frame_time = min([self.img_left_deque[-1].header.stamp.to_sec(), self.img_right_deque[-1].header.stamp.to_sec(), self.img_front_deque[-1].header.stamp.to_sec()]) + + if len(self.img_left_deque) == 0 or self.img_left_deque[-1].header.stamp.to_sec() < frame_time: + return False + if len(self.img_right_deque) == 0 or self.img_right_deque[-1].header.stamp.to_sec() < frame_time: + return False + if len(self.img_front_deque) == 0 or self.img_front_deque[-1].header.stamp.to_sec() < frame_time: + return False + if len(self.master_arm_left_deque) == 0 or self.master_arm_left_deque[-1].header.stamp.to_sec() < frame_time: + return False + if len(self.master_arm_right_deque) == 0 or self.master_arm_right_deque[-1].header.stamp.to_sec() < frame_time: + return False + if len(self.puppet_arm_left_deque) == 0 or self.puppet_arm_left_deque[-1].header.stamp.to_sec() < frame_time: + return False + if len(self.puppet_arm_right_deque) == 0 or self.puppet_arm_right_deque[-1].header.stamp.to_sec() < frame_time: + return False + if self.args.use_depth_image and (len(self.img_left_depth_deque) == 0 or self.img_left_depth_deque[-1].header.stamp.to_sec() < frame_time): + return False + if self.args.use_depth_image and (len(self.img_right_depth_deque) == 0 or self.img_right_depth_deque[-1].header.stamp.to_sec() < frame_time): + return False + if self.args.use_depth_image and (len(self.img_front_depth_deque) == 0 or self.img_front_depth_deque[-1].header.stamp.to_sec() < frame_time): + return False + if self.args.use_robot_base and (len(self.robot_base_deque) == 0 or self.robot_base_deque[-1].header.stamp.to_sec() < frame_time): + return False + + while self.img_left_deque[0].header.stamp.to_sec() < frame_time: + self.img_left_deque.popleft() + img_left = self.bridge.imgmsg_to_cv2(self.img_left_deque.popleft(), 'passthrough') + # print("img_left:", img_left.shape) + + while self.img_right_deque[0].header.stamp.to_sec() < frame_time: + self.img_right_deque.popleft() + img_right = self.bridge.imgmsg_to_cv2(self.img_right_deque.popleft(), 'passthrough') + + while self.img_front_deque[0].header.stamp.to_sec() < frame_time: + self.img_front_deque.popleft() + img_front = self.bridge.imgmsg_to_cv2(self.img_front_deque.popleft(), 'passthrough') + + while self.master_arm_left_deque[0].header.stamp.to_sec() < frame_time: + self.master_arm_left_deque.popleft() + master_arm_left = self.master_arm_left_deque.popleft() + + while self.master_arm_right_deque[0].header.stamp.to_sec() < frame_time: + self.master_arm_right_deque.popleft() + master_arm_right = self.master_arm_right_deque.popleft() + + while self.puppet_arm_left_deque[0].header.stamp.to_sec() < frame_time: + self.puppet_arm_left_deque.popleft() + puppet_arm_left = self.puppet_arm_left_deque.popleft() + + while self.puppet_arm_right_deque[0].header.stamp.to_sec() < frame_time: + self.puppet_arm_right_deque.popleft() + puppet_arm_right = self.puppet_arm_right_deque.popleft() + + img_left_depth = None + if self.args.use_depth_image: + while self.img_left_depth_deque[0].header.stamp.to_sec() < frame_time: + self.img_left_depth_deque.popleft() + img_left_depth = self.bridge.imgmsg_to_cv2(self.img_left_depth_deque.popleft(), 'passthrough') + top, bottom, left, right = 40, 40, 0, 0 + img_left_depth = cv2.copyMakeBorder(img_left_depth, top, bottom, left, right, 
cv2.BORDER_CONSTANT, value=0) + + img_right_depth = None + if self.args.use_depth_image: + while self.img_right_depth_deque[0].header.stamp.to_sec() < frame_time: + self.img_right_depth_deque.popleft() + img_right_depth = self.bridge.imgmsg_to_cv2(self.img_right_depth_deque.popleft(), 'passthrough') + top, bottom, left, right = 40, 40, 0, 0 + img_right_depth = cv2.copyMakeBorder(img_right_depth, top, bottom, left, right, cv2.BORDER_CONSTANT, value=0) + + img_front_depth = None + if self.args.use_depth_image: + while self.img_front_depth_deque[0].header.stamp.to_sec() < frame_time: + self.img_front_depth_deque.popleft() + img_front_depth = self.bridge.imgmsg_to_cv2(self.img_front_depth_deque.popleft(), 'passthrough') + top, bottom, left, right = 40, 40, 0, 0 + img_front_depth = cv2.copyMakeBorder(img_front_depth, top, bottom, left, right, cv2.BORDER_CONSTANT, value=0) + + robot_base = None + if self.args.use_robot_base: + while self.robot_base_deque[0].header.stamp.to_sec() < frame_time: + self.robot_base_deque.popleft() + robot_base = self.robot_base_deque.popleft() + + return (img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth, + puppet_arm_left, puppet_arm_right, master_arm_left, master_arm_right, robot_base) + + def img_left_callback(self, msg): + if len(self.img_left_deque) >= 2000: + self.img_left_deque.popleft() + self.img_left_deque.append(msg) + + def img_right_callback(self, msg): + if len(self.img_right_deque) >= 2000: + self.img_right_deque.popleft() + self.img_right_deque.append(msg) + + def img_front_callback(self, msg): + if len(self.img_front_deque) >= 2000: + self.img_front_deque.popleft() + self.img_front_deque.append(msg) + + def img_left_depth_callback(self, msg): + # import pdb + # pdb.set_trace() + if len(self.img_left_depth_deque) >= 2000: + self.img_left_depth_deque.popleft() + self.img_left_depth_deque.append(msg) + + def img_right_depth_callback(self, msg): + if len(self.img_right_depth_deque) >= 2000: + self.img_right_depth_deque.popleft() + self.img_right_depth_deque.append(msg) + + def img_front_depth_callback(self, msg): + if len(self.img_front_depth_deque) >= 2000: + self.img_front_depth_deque.popleft() + self.img_front_depth_deque.append(msg) + + def master_arm_left_callback(self, msg): + if len(self.master_arm_left_deque) >= 2000: + self.master_arm_left_deque.popleft() + self.master_arm_left_deque.append(msg) + + def master_arm_right_callback(self, msg): + if len(self.master_arm_right_deque) >= 2000: + self.master_arm_right_deque.popleft() + self.master_arm_right_deque.append(msg) + + def puppet_arm_left_callback(self, msg): + if len(self.puppet_arm_left_deque) >= 2000: + self.puppet_arm_left_deque.popleft() + self.puppet_arm_left_deque.append(msg) + + def puppet_arm_right_callback(self, msg): + if len(self.puppet_arm_right_deque) >= 2000: + self.puppet_arm_right_deque.popleft() + self.puppet_arm_right_deque.append(msg) + + def robot_base_callback(self, msg): + if len(self.robot_base_deque) >= 2000: + self.robot_base_deque.popleft() + self.robot_base_deque.append(msg) + + def init_ros(self): + rospy.init_node('record_episodes', anonymous=True) + rospy.Subscriber(self.args.img_left_topic, Image, self.img_left_callback, queue_size=1000, tcp_nodelay=True) + rospy.Subscriber(self.args.img_right_topic, Image, self.img_right_callback, queue_size=1000, tcp_nodelay=True) + rospy.Subscriber(self.args.img_front_topic, Image, self.img_front_callback, queue_size=1000, tcp_nodelay=True) + if self.args.use_depth_image: + 
rospy.Subscriber(self.args.img_left_depth_topic, Image, self.img_left_depth_callback, queue_size=1000, tcp_nodelay=True) + rospy.Subscriber(self.args.img_right_depth_topic, Image, self.img_right_depth_callback, queue_size=1000, tcp_nodelay=True) + rospy.Subscriber(self.args.img_front_depth_topic, Image, self.img_front_depth_callback, queue_size=1000, tcp_nodelay=True) + + rospy.Subscriber(self.args.master_arm_left_topic, JointState, self.master_arm_left_callback, queue_size=1000, tcp_nodelay=True) + rospy.Subscriber(self.args.master_arm_right_topic, JointState, self.master_arm_right_callback, queue_size=1000, tcp_nodelay=True) + rospy.Subscriber(self.args.puppet_arm_left_topic, JointState, self.puppet_arm_left_callback, queue_size=1000, tcp_nodelay=True) + rospy.Subscriber(self.args.puppet_arm_right_topic, JointState, self.puppet_arm_right_callback, queue_size=1000, tcp_nodelay=True) + # rospy.Subscriber(self.args.robot_base_topic, Odometry, self.robot_base_callback, queue_size=1000, tcp_nodelay=True) + + def process(self): + timesteps = [] + actions = [] + # 图像数据 + image = np.random.randint(0, 255, size=(480, 640, 3), dtype=np.uint8) + image_dict = dict() + for cam_name in self.args.camera_names: + image_dict[cam_name] = image + count = 0 + + # input_key = input("please input s:") + # while input_key != 's' and not rospy.is_shutdown(): + # input_key = input("please input s:") + + rate = rospy.Rate(self.args.frame_rate) + print_flag = True + + while (count < self.args.max_timesteps + 1) and not rospy.is_shutdown(): + # 2 收集数据 + result = self.get_frame() + # import pdb + # pdb.set_trace() + if not result: + if print_flag: + print("syn fail\n") + print_flag = False + rate.sleep() + continue + print_flag = True + count += 1 + (img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth, + puppet_arm_left, puppet_arm_right, master_arm_left, master_arm_right, robot_base) = result + # 2.1 图像信息 + image_dict = dict() + image_dict[self.args.camera_names[0]] = img_front + image_dict[self.args.camera_names[1]] = img_left + image_dict[self.args.camera_names[2]] = img_right + + # import pdb + # pdb.set_trace() + display_camera_grid(image_dict) + + # 2.2 从臂的信息从臂的状态 机械臂示教模式时 会自动订阅 + obs = collections.OrderedDict() # 有序的字典 + obs['images'] = image_dict + if self.args.use_depth_image: + image_dict_depth = dict() + image_dict_depth[self.args.camera_names[0]] = img_front_depth + image_dict_depth[self.args.camera_names[1]] = img_left_depth + image_dict_depth[self.args.camera_names[2]] = img_right_depth + obs['images_depth'] = image_dict_depth + obs['qpos'] = np.concatenate((np.array(puppet_arm_left.position), np.array(puppet_arm_right.position)), axis=0) + obs['qvel'] = np.concatenate((np.array(puppet_arm_left.velocity), np.array(puppet_arm_right.velocity)), axis=0) + obs['effort'] = np.concatenate((np.array(puppet_arm_left.effort), np.array(puppet_arm_right.effort)), axis=0) + if self.args.use_robot_base: + obs['base_vel'] = [robot_base.twist.twist.linear.x, robot_base.twist.twist.angular.z] + else: + obs['base_vel'] = [0.0, 0.0] + + # 第一帧 只包含first, fisrt只保存StepType.FIRST + if count == 1: + ts = dm_env.TimeStep( + step_type=dm_env.StepType.FIRST, + reward=None, + discount=None, + observation=obs) + timesteps.append(ts) + continue + + # 时间步 + ts = dm_env.TimeStep( + step_type=dm_env.StepType.MID, + reward=None, + discount=None, + observation=obs) + + # 主臂保存状态 + action = np.concatenate((np.array(master_arm_left.position), np.array(master_arm_right.position)), axis=0) + 
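# Illustrative note (added for clarity, not part of the original patch): `action` here is the
# 14-dim concatenation of the master (leader) arm joint positions -- 7 values per arm as listed
# in the agilex.yaml motor configuration -- while obs['qpos'] above holds the puppet (follower)
# arm positions, so each MID timestep pairs the follower's observed state with the leader
# command captured at the same synchronized frame_time.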
actions.append(action) + timesteps.append(ts) + print(f"\n{self.args.episode_idx} | Frame data: {count}\r", end="") + if rospy.is_shutdown(): + exit(-1) + rate.sleep() + + print("len(timesteps): ", len(timesteps)) + print("len(actions) : ", len(actions)) + return timesteps, actions diff --git a/collect_data/collect_data.py b/collect_data/collect_data.py new file mode 100755 index 0000000..f762221 --- /dev/null +++ b/collect_data/collect_data.py @@ -0,0 +1,166 @@ +import os +import time +import argparse +from aloha_mobile import AlohaRobotRos +from utils import save_data, init_keyboard_listener + + +def get_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument('--dataset_dir', action='store', type=str, help='Dataset_dir.', + default="./data", required=False) + parser.add_argument('--task_name', action='store', type=str, help='Task name.', + default="aloha_mobile_dummy", required=False) + parser.add_argument('--episode_idx', action='store', type=int, help='Episode index.', + default=0, required=False) + parser.add_argument('--max_timesteps', action='store', type=int, help='Max_timesteps.', + default=500, required=False) + parser.add_argument('--camera_names', action='store', type=str, help='camera_names', + default=['cam_high', 'cam_left_wrist', 'cam_right_wrist'], required=False) + parser.add_argument('--num_episodes', action='store', type=int, help='Num_episodes.', + default=1, required=False) + + + # topic name of color image + parser.add_argument('--img_front_topic', action='store', type=str, help='img_front_topic', + default='/camera_f/color/image_raw', required=False) + parser.add_argument('--img_left_topic', action='store', type=str, help='img_left_topic', + default='/camera_l/color/image_raw', required=False) + parser.add_argument('--img_right_topic', action='store', type=str, help='img_right_topic', + default='/camera_r/color/image_raw', required=False) + + # topic name of depth image + parser.add_argument('--img_front_depth_topic', action='store', type=str, help='img_front_depth_topic', + default='/camera_f/depth/image_raw', required=False) + parser.add_argument('--img_left_depth_topic', action='store', type=str, help='img_left_depth_topic', + default='/camera_l/depth/image_raw', required=False) + parser.add_argument('--img_right_depth_topic', action='store', type=str, help='img_right_depth_topic', + default='/camera_r/depth/image_raw', required=False) + + # topic name of arm + parser.add_argument('--master_arm_left_topic', action='store', type=str, help='master_arm_left_topic', + default='/master/joint_left', required=False) + parser.add_argument('--master_arm_right_topic', action='store', type=str, help='master_arm_right_topic', + default='/master/joint_right', required=False) + parser.add_argument('--puppet_arm_left_topic', action='store', type=str, help='puppet_arm_left_topic', + default='/puppet/joint_left', required=False) + parser.add_argument('--puppet_arm_right_topic', action='store', type=str, help='puppet_arm_right_topic', + default='/puppet/joint_right', required=False) + + # topic name of robot_base + parser.add_argument('--robot_base_topic', action='store', type=str, help='robot_base_topic', + default='/odom', required=False) + + parser.add_argument('--use_robot_base', action='store', type=bool, help='use_robot_base', + default=False, required=False) + + # collect depth image + parser.add_argument('--use_depth_image', action='store', type=bool, help='use_depth_image', + default=False, required=False) + + parser.add_argument('--frame_rate', 
action='store', type=int, help='frame_rate', + default=30, required=False) + + args = parser.parse_args() + return args + + +def main(): + args = get_arguments() + ros_operator = AlohaRobotRos(args) + dataset_dir = os.path.join(args.dataset_dir, args.task_name) + # 确保数据集目录存在 + os.makedirs(dataset_dir, exist_ok=True) + # 单集收集模式 + if args.num_episodes == 1: + print(f"Recording single episode {args.episode_idx}...") + timesteps, actions = ros_operator.process() + + if len(actions) < args.max_timesteps: + print(f"\033[31m\nSave failure: Recorded only {len(actions)}/{args.max_timesteps} timesteps.\033[0m\n") + return -1 + + dataset_path = os.path.join(dataset_dir, f"episode_{args.episode_idx}") + save_data(args, timesteps, actions, dataset_path) + print(f"\033[32mEpisode {args.episode_idx} saved successfully at {dataset_path}\033[0m") + return 0 + # 多集收集模式 + print(""" +\033[1;36mKeyboard Controls:\033[0m +← \033[1mLeft Arrow\033[0m: Start Recording +→ \033[1mRight Arrow\033[0m: Save Current Data +↓ \033[1mDown Arrow\033[0m: Discard Current Data +↑ \033[1mUp Arrow\033[0m: Replay Data (if implemented) +\033[1mESC\033[0m: Exit Program +""") + # 初始化键盘监听器 + listener, events = init_keyboard_listener() + episode_idx = args.episode_idx + collected_episodes = 0 + + try: + while collected_episodes < args.num_episodes: + if events["exit_early"]: + print("\033[33mOperation terminated by user\033[0m") + break + + if events["record_start"]: + # 重置事件状态,开始新的录制 + events["record_start"] = False + events["save_data"] = False + events["discard_data"] = False + + print(f"\n\033[1;32mRecording episode {episode_idx}...\033[0m") + timesteps, actions = ros_operator.process() + print(f"\033[1;33mRecorded {len(actions)} timesteps. (→ to save, ↓ to discard)\033[0m") + + # 等待用户决定保存或丢弃 + while True: + if events["save_data"]: + events["save_data"] = False + + if len(actions) < args.max_timesteps: + print(f"\033[31mSave failure: Recorded only {len(actions)}/{args.max_timesteps} timesteps.\033[0m") + else: + dataset_path = os.path.join(dataset_dir, f"episode_{episode_idx}") + save_data(args, timesteps, actions, dataset_path) + print(f"\033[32mEpisode {episode_idx} saved successfully at {dataset_path}\033[0m") + episode_idx += 1 + collected_episodes += 1 + print(f"\033[1mProgress: {collected_episodes}/{args.num_episodes} episodes collected. (← to start new episode)\033[0m") + break + + if events["discard_data"]: + events["discard_data"] = False + print("\033[33mData discarded. Press ← to start a new recording.\033[0m") + break + + if events["exit_early"]: + print("\033[33mOperation terminated by user\033[0m") + return 0 + + time.sleep(0.1) # 减少CPU使用率 + + time.sleep(0.1) # 减少CPU使用率 + + if collected_episodes == args.num_episodes: + print(f"\n\033[1;32mData collection complete! 
All {args.num_episodes} episodes collected.\033[0m") + + finally: + # 确保监听器被清理 + if listener: + listener.stop() + print("Keyboard listener stopped") + + +if __name__ == '__main__': + try: + exit_code = main() + exit(exit_code if exit_code is not None else 0) + except KeyboardInterrupt: + print("\n\033[33mProgram interrupted by user\033[0m") + exit(0) + except Exception as e: + print(f"\n\033[31mError: {e}\033[0m") + +# python collect_data.py --dataset_dir ~/data --max_timesteps 500 --episode_idx 0 diff --git a/collect_data/collect_data_gui.py b/collect_data/collect_data_gui.py new file mode 100644 index 0000000..7497288 --- /dev/null +++ b/collect_data/collect_data_gui.py @@ -0,0 +1,416 @@ +import os +import sys +import time +import argparse +from aloha_mobile import AlohaRobotRos +from utils import save_data, init_keyboard_listener +from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout, + QLabel, QLineEdit, QPushButton, QCheckBox, QSpinBox, + QGroupBox, QFormLayout, QTabWidget, QTextEdit, QFileDialog, + QMessageBox, QProgressBar, QComboBox) +from PyQt5.QtCore import Qt, QThread, pyqtSignal, pyqtSlot +from PyQt5.QtGui import QFont, QIcon, QTextCursor + +class DataCollectionThread(QThread): + """处理数据收集的线程""" + update_signal = pyqtSignal(str) + progress_signal = pyqtSignal(int) + finish_signal = pyqtSignal(bool, str) + + def __init__(self, args, parent=None): + super(DataCollectionThread, self).__init__(parent) + self.args = args + self.is_running = True + self.ros_operator = None + + def run(self): + try: + self.update_signal.emit("正在初始化ROS操作...\n") + self.ros_operator = AlohaRobotRos(self.args) + dataset_dir = os.path.join(self.args.dataset_dir, self.args.task_name) + os.makedirs(dataset_dir, exist_ok=True) + + # 单集收集模式 + if self.args.num_episodes == 1: + self.update_signal.emit(f"开始录制第 {self.args.episode_idx} 集...\n") + timesteps, actions = self.ros_operator.process() + + if len(actions) < self.args.max_timesteps: + self.update_signal.emit(f"保存失败: 只录制了 {len(actions)}/{self.args.max_timesteps} 个时间步.\n") + self.finish_signal.emit(False, f"只录制了 {len(actions)}/{self.args.max_timesteps} 个时间步") + return + + dataset_path = os.path.join(dataset_dir, f"episode_{self.args.episode_idx}") + save_data(self.args, timesteps, actions, dataset_path) + self.update_signal.emit(f"第 {self.args.episode_idx} 集成功保存到 {dataset_path}.\n") + self.finish_signal.emit(True, "数据收集完成") + + # 多集收集模式 + else: + self.update_signal.emit(""" +键盘控制: +← 左箭头: 开始录制 +→ 右箭头: 保存当前数据 +↓ 下箭头: 丢弃当前数据 +ESC: 退出程序 +""") + # 初始化键盘监听器 + listener, events = init_keyboard_listener() + episode_idx = self.args.episode_idx + collected_episodes = 0 + + try: + while collected_episodes < self.args.num_episodes and self.is_running: + if events["exit_early"]: + self.update_signal.emit("操作被用户终止.\n") + break + + if events["record_start"]: + # 重置事件状态,开始新的录制 + events["record_start"] = False + events["save_data"] = False + events["discard_data"] = False + + self.update_signal.emit(f"\n正在录制第 {episode_idx} 集...\n") + timesteps, actions = self.ros_operator.process() + self.update_signal.emit(f"已录制 {len(actions)} 个时间步. 
(→ 保存, ↓ 丢弃)\n") + + # 等待用户决定保存或丢弃 + while self.is_running: + if events["save_data"]: + events["save_data"] = False + + if len(actions) < self.args.max_timesteps: + self.update_signal.emit(f"保存失败: 只录制了 {len(actions)}/{self.args.max_timesteps} 个时间步.\n") + else: + dataset_path = os.path.join(dataset_dir, f"episode_{episode_idx}") + save_data(self.args, timesteps, actions, dataset_path) + self.update_signal.emit(f"第 {episode_idx} 集成功保存到 {dataset_path}.\n") + episode_idx += 1 + collected_episodes += 1 + progress_percentage = int(collected_episodes * 100 / self.args.num_episodes) + self.progress_signal.emit(progress_percentage) + self.update_signal.emit(f"进度: {collected_episodes}/{self.args.num_episodes} 集已收集. (← 开始新一集)\n") + break + + if events["discard_data"]: + events["discard_data"] = False + self.update_signal.emit("数据已丢弃. 请按 ← 开始新的录制.\n") + break + + if events["exit_early"]: + self.update_signal.emit("操作被用户终止.\n") + self.is_running = False + break + + time.sleep(0.1) # 减少CPU使用率 + + time.sleep(0.1) # 减少CPU使用率 + + if collected_episodes == self.args.num_episodes: + self.update_signal.emit(f"\n数据收集完成! 所有 {self.args.num_episodes} 集已收集.\n") + self.finish_signal.emit(True, "全部数据集收集完成") + else: + self.finish_signal.emit(False, "数据收集未完成") + + finally: + # 确保监听器被清理 + if listener: + listener.stop() + self.update_signal.emit("键盘监听器已停止\n") + + except Exception as e: + self.update_signal.emit(f"错误: {str(e)}\n") + self.finish_signal.emit(False, str(e)) + + def stop(self): + self.is_running = False + self.wait() + +class AlohaDataCollectionGUI(QMainWindow): + def __init__(self): + super().__init__() + self.setWindowTitle("ALOHA 数据收集工具") + self.setGeometry(100, 100, 800, 700) + + # 主组件 + self.central_widget = QWidget() + self.setCentralWidget(self.central_widget) + self.main_layout = QVBoxLayout(self.central_widget) + + # 创建选项卡 + self.tab_widget = QTabWidget() + self.main_layout.addWidget(self.tab_widget) + + # 创建配置选项卡 + self.config_tab = QWidget() + self.tab_widget.addTab(self.config_tab, "配置") + + # 创建数据收集选项卡 + self.collection_tab = QWidget() + self.tab_widget.addTab(self.collection_tab, "数据收集") + + self.setup_config_tab() + self.setup_collection_tab() + + # 初始化数据收集线程 + self.collection_thread = None + + def setup_config_tab(self): + config_layout = QVBoxLayout(self.config_tab) + + # 基本配置组 + basic_group = QGroupBox("基本配置") + basic_form = QFormLayout() + + self.dataset_dir = QLineEdit("./data") + self.browse_button = QPushButton("浏览") + self.browse_button.clicked.connect(self.browse_dataset_dir) + + dir_layout = QHBoxLayout() + dir_layout.addWidget(self.dataset_dir) + dir_layout.addWidget(self.browse_button) + + self.task_name = QLineEdit("aloha_mobile_dummy") + self.episode_idx = QSpinBox() + self.episode_idx.setRange(0, 1000) + self.max_timesteps = QSpinBox() + self.max_timesteps.setRange(1, 10000) + self.max_timesteps.setValue(500) + self.num_episodes = QSpinBox() + self.num_episodes.setRange(1, 100) + self.num_episodes.setValue(1) + self.frame_rate = QSpinBox() + self.frame_rate.setRange(1, 120) + self.frame_rate.setValue(30) + + basic_form.addRow("数据集目录:", dir_layout) + basic_form.addRow("任务名称:", self.task_name) + basic_form.addRow("起始集索引:", self.episode_idx) + basic_form.addRow("最大时间步:", self.max_timesteps) + basic_form.addRow("集数:", self.num_episodes) + basic_form.addRow("帧率:", self.frame_rate) + + basic_group.setLayout(basic_form) + config_layout.addWidget(basic_group) + + # 相机话题组 + camera_group = QGroupBox("相机话题") + camera_form = QFormLayout() + + self.img_front_topic = 
QLineEdit('/camera_f/color/image_raw') + self.img_left_topic = QLineEdit('/camera_l/color/image_raw') + self.img_right_topic = QLineEdit('/camera_r/color/image_raw') + + self.img_front_depth_topic = QLineEdit('/camera_f/depth/image_raw') + self.img_left_depth_topic = QLineEdit('/camera_l/depth/image_raw') + self.img_right_depth_topic = QLineEdit('/camera_r/depth/image_raw') + + self.use_depth_image = QCheckBox("使用深度图像") + + camera_form.addRow("前置相机:", self.img_front_topic) + camera_form.addRow("左腕相机:", self.img_left_topic) + camera_form.addRow("右腕相机:", self.img_right_topic) + camera_form.addRow("前置深度:", self.img_front_depth_topic) + camera_form.addRow("左腕深度:", self.img_left_depth_topic) + camera_form.addRow("右腕深度:", self.img_right_depth_topic) + camera_form.addRow("", self.use_depth_image) + + camera_group.setLayout(camera_form) + config_layout.addWidget(camera_group) + + # 机器人话题组 + robot_group = QGroupBox("机器人话题") + robot_form = QFormLayout() + + self.master_arm_left_topic = QLineEdit('/master/joint_left') + self.master_arm_right_topic = QLineEdit('/master/joint_right') + self.puppet_arm_left_topic = QLineEdit('/puppet/joint_left') + self.puppet_arm_right_topic = QLineEdit('/puppet/joint_right') + self.robot_base_topic = QLineEdit('/odom') + self.use_robot_base = QCheckBox("使用机器人底盘") + + robot_form.addRow("主左臂:", self.master_arm_left_topic) + robot_form.addRow("主右臂:", self.master_arm_right_topic) + robot_form.addRow("从左臂:", self.puppet_arm_left_topic) + robot_form.addRow("从右臂:", self.puppet_arm_right_topic) + robot_form.addRow("底盘:", self.robot_base_topic) + robot_form.addRow("", self.use_robot_base) + + robot_group.setLayout(robot_form) + config_layout.addWidget(robot_group) + + # 相机名称配置 + camera_names_group = QGroupBox("相机名称") + camera_names_layout = QVBoxLayout() + + self.camera_names = ['cam_high', 'cam_left_wrist', 'cam_right_wrist'] + self.camera_checkboxes = {} + + for cam_name in self.camera_names: + self.camera_checkboxes[cam_name] = QCheckBox(cam_name) + self.camera_checkboxes[cam_name].setChecked(True) + camera_names_layout.addWidget(self.camera_checkboxes[cam_name]) + + camera_names_group.setLayout(camera_names_layout) + config_layout.addWidget(camera_names_group) + + # 保存配置按钮 + self.save_config_button = QPushButton("保存配置") + self.save_config_button.clicked.connect(self.save_config) + config_layout.addWidget(self.save_config_button) + + def setup_collection_tab(self): + collection_layout = QVBoxLayout(self.collection_tab) + + # 当前配置展示 + config_group = QGroupBox("当前配置") + self.config_text = QTextEdit() + self.config_text.setReadOnly(True) + config_layout = QVBoxLayout() + config_layout.addWidget(self.config_text) + config_group.setLayout(config_layout) + collection_layout.addWidget(config_group) + + # 操作按钮 + buttons_layout = QHBoxLayout() + + self.start_button = QPushButton("开始收集") + self.start_button.setIcon(QIcon.fromTheme("media-playback-start")) + self.start_button.clicked.connect(self.start_collection) + + self.stop_button = QPushButton("停止") + self.stop_button.setIcon(QIcon.fromTheme("media-playback-stop")) + self.stop_button.setEnabled(False) + self.stop_button.clicked.connect(self.stop_collection) + + buttons_layout.addWidget(self.start_button) + buttons_layout.addWidget(self.stop_button) + collection_layout.addLayout(buttons_layout) + + # 进度条 + self.progress_bar = QProgressBar() + self.progress_bar.setValue(0) + collection_layout.addWidget(self.progress_bar) + + # 日志输出 + log_group = QGroupBox("操作日志") + self.log_text = QTextEdit() + self.log_text.setReadOnly(True) + 
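# Note (added for clarity, not part of the original patch): the log view is read-only;
# collection output arrives asynchronously through DataCollectionThread.update_signal,
# and update_log() below appends each message and scrolls the text cursor to the end.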
log_layout = QVBoxLayout() + log_layout.addWidget(self.log_text) + log_group.setLayout(log_layout) + collection_layout.addWidget(log_group) + + def browse_dataset_dir(self): + directory = QFileDialog.getExistingDirectory(self, "选择数据集目录", self.dataset_dir.text()) + if directory: + self.dataset_dir.setText(directory) + + def save_config(self): + # 更新配置显示 + selected_cameras = [cam for cam, checkbox in self.camera_checkboxes.items() if checkbox.isChecked()] + + config_text = f""" +任务名称: {self.task_name.text()} +数据集目录: {self.dataset_dir.text()} +起始集索引: {self.episode_idx.value()} +最大时间步: {self.max_timesteps.value()} +集数: {self.num_episodes.value()} +帧率: {self.frame_rate.value()} +使用深度图像: {"是" if self.use_depth_image.isChecked() else "否"} +使用机器人底盘: {"是" if self.use_robot_base.isChecked() else "否"} +相机: {', '.join(selected_cameras)} + """ + + self.config_text.setText(config_text) + self.tab_widget.setCurrentIndex(1) # 切换到收集选项卡 + + QMessageBox.information(self, "配置已保存", "配置已更新,可以开始数据收集") + + def start_collection(self): + if not self.task_name.text(): + QMessageBox.warning(self, "配置错误", "请输入有效的任务名称") + return + + # 构建参数 + args = argparse.Namespace( + dataset_dir=self.dataset_dir.text(), + task_name=self.task_name.text(), + episode_idx=self.episode_idx.value(), + max_timesteps=self.max_timesteps.value(), + num_episodes=self.num_episodes.value(), + camera_names=[cam for cam, checkbox in self.camera_checkboxes.items() if checkbox.isChecked()], + img_front_topic=self.img_front_topic.text(), + img_left_topic=self.img_left_topic.text(), + img_right_topic=self.img_right_topic.text(), + img_front_depth_topic=self.img_front_depth_topic.text(), + img_left_depth_topic=self.img_left_depth_topic.text(), + img_right_depth_topic=self.img_right_depth_topic.text(), + master_arm_left_topic=self.master_arm_left_topic.text(), + master_arm_right_topic=self.master_arm_right_topic.text(), + puppet_arm_left_topic=self.puppet_arm_left_topic.text(), + puppet_arm_right_topic=self.puppet_arm_right_topic.text(), + robot_base_topic=self.robot_base_topic.text(), + use_robot_base=self.use_robot_base.isChecked(), + use_depth_image=self.use_depth_image.isChecked(), + frame_rate=self.frame_rate.value() + ) + + # 更新UI状态 + self.start_button.setEnabled(False) + self.stop_button.setEnabled(True) + self.progress_bar.setValue(0) + self.log_text.clear() + self.log_text.append("正在初始化数据收集...\n") + + # 创建并启动线程 + self.collection_thread = DataCollectionThread(args) + self.collection_thread.update_signal.connect(self.update_log) + self.collection_thread.progress_signal.connect(self.update_progress) + self.collection_thread.finish_signal.connect(self.collection_finished) + self.collection_thread.start() + + def stop_collection(self): + if self.collection_thread and self.collection_thread.isRunning(): + self.log_text.append("正在停止数据收集...\n") + self.collection_thread.stop() + self.collection_thread = None + + self.start_button.setEnabled(True) + self.stop_button.setEnabled(False) + + @pyqtSlot(str) + def update_log(self, message): + self.log_text.append(message) + # 自动滚动到底部 + cursor = self.log_text.textCursor() + cursor.movePosition(QTextCursor.End) + self.log_text.setTextCursor(cursor) + + @pyqtSlot(int) + def update_progress(self, value): + self.progress_bar.setValue(value) + + @pyqtSlot(bool, str) + def collection_finished(self, success, message): + self.start_button.setEnabled(True) + self.stop_button.setEnabled(False) + + if success: + QMessageBox.information(self, "完成", message) + else: + QMessageBox.warning(self, "出错", f"数据收集失败: {message}") + 
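# Note (added for clarity, not part of the original patch): on a successful run the start-index
# spin box is advanced by the number of episodes just collected (see below), so the next
# collection launched from the GUI continues the episode_<idx> numbering instead of
# overwriting previously saved files.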
+ # 更新episode_idx值 + if success and self.num_episodes.value() > 0: + self.episode_idx.setValue(self.episode_idx.value() + self.num_episodes.value()) + +def main(): + app = QApplication(sys.argv) + window = AlohaDataCollectionGUI() + window.show() + sys.exit(app.exec_()) + +if __name__ == "__main__": + main() diff --git a/collect_data/collect_data_lerobot.py b/collect_data/collect_data_lerobot.py deleted file mode 100644 index e880e54..0000000 --- a/collect_data/collect_data_lerobot.py +++ /dev/null @@ -1,462 +0,0 @@ -import logging -import time -from dataclasses import asdict -from pprint import pformat -from pprint import pprint - -# from safetensors.torch import load_file, save_file -from lerobot.common.datasets.lerobot_dataset import LeRobotDataset -from lerobot.common.policies.factory import make_policy -from lerobot.common.robot_devices.control_configs import ( - CalibrateControlConfig, - ControlPipelineConfig, - RecordControlConfig, - RemoteRobotConfig, - ReplayControlConfig, - TeleoperateControlConfig, -) -from lerobot.common.robot_devices.control_utils import ( - # init_keyboard_listener, - record_episode, - stop_recording, - is_headless -) -from lerobot.common.robot_devices.robots.utils import Robot, make_robot_from_config -from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect -from lerobot.common.utils.utils import has_method, init_logging, log_say -from lerobot.common.utils.utils import get_safe_torch_device -from contextlib import nullcontext -from copy import copy -import torch -import rospy -import cv2 -from lerobot.configs import parser -from agilex_robot import AgilexRobot - - -######################################################################################## -# Control modes -######################################################################################## - - -def predict_action(observation, policy, device, use_amp): - observation = copy(observation) - with ( - torch.inference_mode(), - torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(), - ): - # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension - for name in observation: - if "image" in name: - observation[name] = observation[name].type(torch.float32) / 255 - observation[name] = observation[name].permute(2, 0, 1).contiguous() - observation[name] = observation[name].unsqueeze(0) - observation[name] = observation[name].to(device) - - # Compute the next action with the policy - # based on the current observation - action = policy.select_action(observation) - - # Remove batch dimension - action = action.squeeze(0) - - # Move to cpu, if not already the case - action = action.to("cpu") - - return action - -def control_loop( - robot, - control_time_s=None, - teleoperate=False, - display_cameras=False, - dataset: LeRobotDataset | None = None, - events=None, - policy = None, - fps: int | None = None, - single_task: str | None = None, -): - # TODO(rcadene): Add option to record logs - # if not robot.is_connected: - # robot.connect() - - if events is None: - events = {"exit_early": False} - - if control_time_s is None: - control_time_s = float("inf") - - if dataset is not None and single_task is None: - raise ValueError("You need to provide a task as argument in `single_task`.") - - if dataset is not None and fps is not None and dataset.fps != fps: - raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).") - - timestamp = 0 - start_episode_t = time.perf_counter() - rate = 
rospy.Rate(fps) - print_flag = True - while timestamp < control_time_s and not rospy.is_shutdown(): - # print(timestamp < control_time_s) - # print(rospy.is_shutdown()) - start_loop_t = time.perf_counter() - - if teleoperate: - observation, action = robot.teleop_step() - if observation is None or action is None: - if print_flag: - print("sync data fail, retrying...\n") - print_flag = False - rate.sleep() - continue - else: - # pass - observation = robot.capture_observation() - if policy is not None: - pred_action = predict_action( - observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp - ) - # Action can eventually be clipped using `max_relative_target`, - # so action actually sent is saved in the dataset. - action = robot.send_action(pred_action) - action = {"action": action} - - if dataset is not None: - frame = {**observation, **action, "task": single_task} - dataset.add_frame(frame) - - # if display_cameras and not is_headless(): - # image_keys = [key for key in observation if "image" in key] - # for key in image_keys: - # if "depth" in key: - # pass - # else: - # cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)) - - # print(1) - # cv2.waitKey(1) - - if display_cameras and not is_headless(): - image_keys = [key for key in observation if "image" in key] - - # 获取屏幕分辨率(假设屏幕分辨率为 1920x1080,可以根据实际情况调整) - screen_width = 1920 - screen_height = 1080 - - # 计算窗口的排列方式 - num_images = len(image_keys) - max_columns = int(screen_width / 640) # 假设每个窗口宽度为 640 - rows = (num_images + max_columns - 1) // max_columns # 计算需要的行数 - columns = min(num_images, max_columns) # 实际使用的列数 - - # 遍历所有图像键并显示 - for idx, key in enumerate(image_keys): - if "depth" in key: - continue # 跳过深度图像 - - # 将图像从 RGB 转换为 BGR 格式 - image = cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR) - - # 创建窗口 - cv2.imshow(key, image) - - # 计算窗口位置 - window_width = 640 - window_height = 480 - row = idx // max_columns - col = idx % max_columns - x_position = col * window_width - y_position = row * window_height - - # 移动窗口到指定位置 - cv2.moveWindow(key, x_position, y_position) - - # 等待 1 毫秒以处理事件 - cv2.waitKey(1) - - if fps is not None: - dt_s = time.perf_counter() - start_loop_t - busy_wait(1 / fps - dt_s) - - dt_s = time.perf_counter() - start_loop_t - # log_control_info(robot, dt_s, fps=fps) - - timestamp = time.perf_counter() - start_episode_t - if events["exit_early"]: - events["exit_early"] = False - break - - -def init_keyboard_listener(): - # Allow to exit early while recording an episode or resetting the environment, - # by tapping the right arrow key '->'. This might require a sudo permission - # to allow your terminal to monitor keyboard events. - events = {} - events["exit_early"] = False - events["record_start"] = False - events["rerecord_episode"] = False - events["stop_recording"] = False - - if is_headless(): - logging.warning( - "Headless environment detected. On-screen cameras display and keyboard inputs will not be available." - ) - listener = None - return listener, events - - # Only import pynput if not in a headless environment - from pynput import keyboard - - def on_press(key): - try: - if key == keyboard.Key.right: - print("Right arrow key pressed. Exiting loop...") - events["exit_early"] = True - events["record_start"] = False - elif key == keyboard.Key.left: - print("Left arrow key pressed. Exiting loop and rerecord the last episode...") - events["rerecord_episode"] = True - events["exit_early"] = True - elif key == keyboard.Key.esc: - print("Escape key pressed. 
Stopping data recording...") - events["stop_recording"] = True - events["exit_early"] = True - elif key == keyboard.Key.up: - print("Up arrow pressed. Start data recording...") - events["record_start"] = True - - - except Exception as e: - print(f"Error handling key press: {e}") - - listener = keyboard.Listener(on_press=on_press) - listener.start() - - return listener, events - - -def stop_recording(robot, listener, display_cameras): - - if not is_headless(): - if listener is not None: - listener.stop() - - if display_cameras: - cv2.destroyAllWindows() - - -def record_episode( - robot, - dataset, - events, - episode_time_s, - display_cameras, - policy, - fps, - single_task, -): - control_loop( - robot=robot, - control_time_s=episode_time_s, - display_cameras=display_cameras, - dataset=dataset, - events=events, - policy=policy, - fps=fps, - teleoperate=policy is None, - single_task=single_task, - ) - - -def record( - robot, - cfg -) -> LeRobotDataset: - # TODO(rcadene): Add option to record logs - if cfg.resume: - dataset = LeRobotDataset( - cfg.repo_id, - root=cfg.root, - ) - if len(robot.cameras) > 0: - dataset.start_image_writer( - num_processes=cfg.num_image_writer_processes, - num_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras), - ) - # sanity_check_dataset_robot_compatibility(dataset, robot, cfg.fps, cfg.video) - else: - # Create empty dataset or load existing saved episodes - # sanity_check_dataset_name(cfg.repo_id, cfg.policy) - dataset = LeRobotDataset.create( - cfg.repo_id, - cfg.fps, - root=cfg.root, - robot=None, - features=robot.features, - use_videos=cfg.video, - image_writer_processes=cfg.num_image_writer_processes, - image_writer_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras), - ) - - # Load pretrained policy - policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta) - # policy = None - - # if not robot.is_connected: - # robot.connect() - - listener, events = init_keyboard_listener() - - # Execute a few seconds without recording to: - # 1. teleoperate the robot to move it in starting position if no policy provided, - # 2. give times to the robot devices to connect and start synchronizing, - # 3. place the cameras windows on screen - enable_teleoperation = policy is None - log_say("Warmup record", cfg.play_sounds) - print() - print(f"开始记录轨迹,共需要记录{cfg.num_episodes}条\n每条轨迹的最长时间为{cfg.episode_time_s}frame\n按右方向键代表当前轨迹结束录制\n按上方面键代表当前轨迹开始录制\n按左方向键代表当前轨迹重新录制\n按ESC方向键代表退出轨迹录制\n") - # warmup_record(robot, events, enable_teleoperation, cfg.warmup_time_s, cfg.display_cameras, cfg.fps) - - # if has_method(robot, "teleop_safety_stop"): - # robot.teleop_safety_stop() - - recorded_episodes = 0 - while True: - if recorded_episodes >= cfg.num_episodes: - break - - # if events["record_start"]: - log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds) - pprint(f"Recording episode {dataset.num_episodes}, total episodes is {cfg.num_episodes}") - record_episode( - robot=robot, - dataset=dataset, - events=events, - episode_time_s=cfg.episode_time_s, - display_cameras=cfg.display_cameras, - policy=policy, - fps=cfg.fps, - single_task=cfg.single_task, - ) - - # Execute a few seconds without recording to give time to manually reset the environment - # Current code logic doesn't allow to teleoperate during this time. 
- # TODO(rcadene): add an option to enable teleoperation during reset - # Skip reset for the last episode to be recorded - if not events["stop_recording"] and ( - (recorded_episodes < cfg.num_episodes - 1) or events["rerecord_episode"] - ): - log_say("Reset the environment", cfg.play_sounds) - pprint("Reset the environment, stop recording") - # reset_environment(robot, events, cfg.reset_time_s, cfg.fps) - - if events["rerecord_episode"]: - log_say("Re-record episode", cfg.play_sounds) - pprint("Re-record episode") - events["rerecord_episode"] = False - events["exit_early"] = False - dataset.clear_episode_buffer() - continue - - dataset.save_episode() - recorded_episodes += 1 - - if events["stop_recording"]: - break - - log_say("Stop recording", cfg.play_sounds, blocking=True) - stop_recording(robot, listener, cfg.display_cameras) - - if cfg.push_to_hub: - dataset.push_to_hub(tags=cfg.tags, private=cfg.private) - - log_say("Exiting", cfg.play_sounds) - return dataset - - -def replay( - robot: AgilexRobot, - cfg, -): - # TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset - # TODO(rcadene): Add option to record logs - - dataset = LeRobotDataset(cfg.repo_id, root=cfg.root, episodes=[cfg.episode]) - actions = dataset.hf_dataset.select_columns("action") - - # if not robot.is_connected: - # robot.connect() - - log_say("Replaying episode", cfg.play_sounds, blocking=True) - for idx in range(dataset.num_frames): - start_episode_t = time.perf_counter() - - action = actions[idx]["action"] - robot.send_action(action) - - dt_s = time.perf_counter() - start_episode_t - busy_wait(1 / cfg.fps - dt_s) - - dt_s = time.perf_counter() - start_episode_t - # log_control_info(robot, dt_s, fps=cfg.fps) - - -import argparse -def get_arguments(): - parser = argparse.ArgumentParser() - args = parser.parse_args() - args.fps = 30 - args.resume = False - args.repo_id = "move_the_bottle_from_the_right_to_the_scale_right" - args.root = "./data5" - args.episode = 0 # replay episode - args.num_image_writer_processes = 0 - args.num_image_writer_threads_per_camera = 4 - args.video = True - args.num_episodes = 100 - args.episode_time_s = 30000 - args.play_sounds = False - args.display_cameras = True - args.single_task = "move the bottle from the right to the scale right" - args.use_depth_image = False - args.use_base = False - args.push_to_hub = False - args.policy = None - # args.teleoprate = True - args.control_type = "record" - # args.control_type = "replay" - return args - - - - -# @parser.wrap() -def control_robot(cfg): - - from rosrobot_factory import RobotFactory - # 使用工厂模式创建机器人实例 - robot = RobotFactory.create(config_file="/home/ubuntu/LYT/lerobot_aloha/collect_data/agilex.yaml", args=cfg) - - if cfg.control_type == "record": - record(robot, cfg) - elif cfg.control_type == "replay": - replay(robot, cfg) - - -if __name__ == "__main__": - cfg = get_arguments() - control_robot(cfg) - # control_robot() - # 使用工厂模式创建机器人实例 - # robot = RobotFactory.create(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg) - # print(robot.features.items()) - # print([key for key, ft in robot.features.items() if ft["dtype"] == "video"]) - # record(robot, cfg) - # capture = robot.capture_observation() - # import torch - # torch.save(capture, "test.pt") - # action = torch.tensor([[ 0.0277, 0.0167, 0.0142, -0.1628, 0.1473, -0.0296, 0.0238, -0.1094, - # 0.0109, 0.0139, -0.1591, -0.1490, -0.1650, -0.0980]], - # device='cpu') - # robot.send_action(action.squeeze(0)) - # 
print() \ No newline at end of file diff --git a/collect_data/export_env.bash b/collect_data/export_env.bash deleted file mode 100644 index 58be2a7..0000000 --- a/collect_data/export_env.bash +++ /dev/null @@ -1,3 +0,0 @@ -export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtiff.so.5 -# export LD_LIBRARY_PATH=/home/ubuntu/miniconda3/envs/lerobot/lib:$LD_LIBRARY_PATH -export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH diff --git a/collect_data/inference.py b/collect_data/inference.py deleted file mode 100644 index 34f7f52..0000000 --- a/collect_data/inference.py +++ /dev/null @@ -1,769 +0,0 @@ -#!/home/lin/software/miniconda3/envs/aloha/bin/python -# -- coding: UTF-8 -""" -#!/usr/bin/python3 -""" - -import torch -import numpy as np -import os -import pickle -import argparse -from einops import rearrange -import collections -from collections import deque - -import rospy -from std_msgs.msg import Header -from geometry_msgs.msg import Twist -from sensor_msgs.msg import JointState, Image -from nav_msgs.msg import Odometry -from cv_bridge import CvBridge -import time -import threading -import math -import threading - - - - -import sys -sys.path.append("./") - -SEED = 42 -torch.manual_seed(SEED) -np.random.seed(SEED) - -task_config = {'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']} - -inference_thread = None -inference_lock = threading.Lock() -inference_actions = None -inference_timestep = None - - -def actions_interpolation(args, pre_action, actions, stats): - steps = np.concatenate((np.array(args.arm_steps_length), np.array(args.arm_steps_length)), axis=0) - pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std'] - post_process = lambda a: a * stats['action_std'] + stats['action_mean'] - result = [pre_action] - post_action = post_process(actions[0]) - # print("pre_action:", pre_action[7:]) - # print("actions_interpolation1:", post_action[:, 7:]) - max_diff_index = 0 - max_diff = -1 - for i in range(post_action.shape[0]): - diff = 0 - for j in range(pre_action.shape[0]): - if j == 6 or j == 13: - continue - diff += math.fabs(pre_action[j] - post_action[i][j]) - if diff > max_diff: - max_diff = diff - max_diff_index = i - - for i in range(max_diff_index, post_action.shape[0]): - step = max([math.floor(math.fabs(result[-1][j] - post_action[i][j])/steps[j]) for j in range(pre_action.shape[0])]) - inter = np.linspace(result[-1], post_action[i], step+2) - result.extend(inter[1:]) - while len(result) < args.chunk_size+1: - result.append(result[-1]) - result = np.array(result)[1:args.chunk_size+1] - # print("actions_interpolation2:", result.shape, result[:, 7:]) - result = pre_process(result) - result = result[np.newaxis, :] - return result - - -def get_model_config(args): - # 设置随机种子,你可以确保在相同的初始条件下,每次运行代码时生成的随机数序列是相同的。 - set_seed(1) - - # 如果是ACT策略 - # fixed parameters - if args.policy_class == 'ACT': - policy_config = {'lr': args.lr, - 'lr_backbone': args.lr_backbone, - 'backbone': args.backbone, - 'masks': args.masks, - 'weight_decay': args.weight_decay, - 'dilation': args.dilation, - 'position_embedding': args.position_embedding, - 'loss_function': args.loss_function, - 'chunk_size': args.chunk_size, # 查询 - 'camera_names': task_config['camera_names'], - 'use_depth_image': args.use_depth_image, - 'use_robot_base': args.use_robot_base, - 'kl_weight': args.kl_weight, # kl散度权重 - 'hidden_dim': args.hidden_dim, # 隐藏层维度 - 'dim_feedforward': args.dim_feedforward, - 'enc_layers': args.enc_layers, - 'dec_layers': args.dec_layers, - 'nheads': args.nheads, - 
'dropout': args.dropout, - 'pre_norm': args.pre_norm - } - elif args.policy_class == 'CNNMLP': - policy_config = {'lr': args.lr, - 'lr_backbone': args.lr_backbone, - 'backbone': args.backbone, - 'masks': args.masks, - 'weight_decay': args.weight_decay, - 'dilation': args.dilation, - 'position_embedding': args.position_embedding, - 'loss_function': args.loss_function, - 'chunk_size': 1, # 查询 - 'camera_names': task_config['camera_names'], - 'use_depth_image': args.use_depth_image, - 'use_robot_base': args.use_robot_base - } - - elif args.policy_class == 'Diffusion': - policy_config = {'lr': args.lr, - 'lr_backbone': args.lr_backbone, - 'backbone': args.backbone, - 'masks': args.masks, - 'weight_decay': args.weight_decay, - 'dilation': args.dilation, - 'position_embedding': args.position_embedding, - 'loss_function': args.loss_function, - 'chunk_size': args.chunk_size, # 查询 - 'camera_names': task_config['camera_names'], - 'use_depth_image': args.use_depth_image, - 'use_robot_base': args.use_robot_base, - 'observation_horizon': args.observation_horizon, - 'action_horizon': args.action_horizon, - 'num_inference_timesteps': args.num_inference_timesteps, - 'ema_power': args.ema_power - } - else: - raise NotImplementedError - - config = { - 'ckpt_dir': args.ckpt_dir, - 'ckpt_name': args.ckpt_name, - 'ckpt_stats_name': args.ckpt_stats_name, - 'episode_len': args.max_publish_step, - 'state_dim': args.state_dim, - 'policy_class': args.policy_class, - 'policy_config': policy_config, - 'temporal_agg': args.temporal_agg, - 'camera_names': task_config['camera_names'], - } - return config - - -def make_policy(policy_class, policy_config): - if policy_class == 'ACT': - policy = ACTPolicy(policy_config) - elif policy_class == 'CNNMLP': - policy = CNNMLPPolicy(policy_config) - elif policy_class == 'Diffusion': - policy = DiffusionPolicy(policy_config) - else: - raise NotImplementedError - return policy - - -def get_image(observation, camera_names): - curr_images = [] - for cam_name in camera_names: - curr_image = rearrange(observation['images'][cam_name], 'h w c -> c h w') - - curr_images.append(curr_image) - curr_image = np.stack(curr_images, axis=0) - curr_image = torch.from_numpy(curr_image / 255.0).float().cuda().unsqueeze(0) - return curr_image - - -def get_depth_image(observation, camera_names): - curr_images = [] - for cam_name in camera_names: - curr_images.append(observation['images_depth'][cam_name]) - curr_image = np.stack(curr_images, axis=0) - curr_image = torch.from_numpy(curr_image / 255.0).float().cuda().unsqueeze(0) - return curr_image - - -def inference_process(args, config, ros_operator, policy, stats, t, pre_action): - global inference_lock - global inference_actions - global inference_timestep - print_flag = True - pre_pos_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std'] - pre_action_process = lambda next_action: (next_action - stats["action_mean"]) / stats["action_std"] - rate = rospy.Rate(args.publish_rate) - while True and not rospy.is_shutdown(): - result = ros_operator.get_frame() - if not result: - if print_flag: - print("syn fail") - print_flag = False - rate.sleep() - continue - print_flag = True - (img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth, - puppet_arm_left, puppet_arm_right, robot_base) = result - obs = collections.OrderedDict() - image_dict = dict() - - image_dict[config['camera_names'][0]] = img_front - image_dict[config['camera_names'][1]] = img_left - image_dict[config['camera_names'][2]] = img_right - - - 
obs['images'] = image_dict - - if args.use_depth_image: - image_depth_dict = dict() - image_depth_dict[config['camera_names'][0]] = img_front_depth - image_depth_dict[config['camera_names'][1]] = img_left_depth - image_depth_dict[config['camera_names'][2]] = img_right_depth - obs['images_depth'] = image_depth_dict - - obs['qpos'] = np.concatenate( - (np.array(puppet_arm_left.position), np.array(puppet_arm_right.position)), axis=0) - obs['qvel'] = np.concatenate( - (np.array(puppet_arm_left.velocity), np.array(puppet_arm_right.velocity)), axis=0) - obs['effort'] = np.concatenate( - (np.array(puppet_arm_left.effort), np.array(puppet_arm_right.effort)), axis=0) - if args.use_robot_base: - obs['base_vel'] = [robot_base.twist.twist.linear.x, robot_base.twist.twist.angular.z] - obs['qpos'] = np.concatenate((obs['qpos'], obs['base_vel']), axis=0) - else: - obs['base_vel'] = [0.0, 0.0] - # qpos_numpy = np.array(obs['qpos']) - - # 归一化处理qpos 并转到cuda - qpos = pre_pos_process(obs['qpos']) - qpos = torch.from_numpy(qpos).float().cuda().unsqueeze(0) - # 当前图像curr_image获取图像 - curr_image = get_image(obs, config['camera_names']) - curr_depth_image = None - if args.use_depth_image: - curr_depth_image = get_depth_image(obs, config['camera_names']) - start_time = time.time() - all_actions = policy(curr_image, curr_depth_image, qpos) - end_time = time.time() - print("model cost time: ", end_time -start_time) - inference_lock.acquire() - inference_actions = all_actions.cpu().detach().numpy() - if pre_action is None: - pre_action = obs['qpos'] - # print("obs['qpos']:", obs['qpos'][7:]) - if args.use_actions_interpolation: - inference_actions = actions_interpolation(args, pre_action, inference_actions, stats) - inference_timestep = t - inference_lock.release() - break - - -def model_inference(args, config, ros_operator, save_episode=True): - global inference_lock - global inference_actions - global inference_timestep - global inference_thread - set_seed(1000) - - # 1 创建模型数据 继承nn.Module - policy = make_policy(config['policy_class'], config['policy_config']) - # print("model structure\n", policy.model) - - # 2 加载模型权重 - ckpt_path = os.path.join(config['ckpt_dir'], config['ckpt_name']) - state_dict = torch.load(ckpt_path) - new_state_dict = {} - for key, value in state_dict.items(): - if key in ["model.is_pad_head.weight", "model.is_pad_head.bias"]: - continue - if key in ["model.input_proj_next_action.weight", "model.input_proj_next_action.bias"]: - continue - new_state_dict[key] = value - loading_status = policy.deserialize(new_state_dict) - if not loading_status: - print("ckpt path not exist") - return False - - # 3 模型设置为cuda模式和验证模式 - policy.cuda() - policy.eval() - - # 4 加载统计值 - stats_path = os.path.join(config['ckpt_dir'], config['ckpt_stats_name']) - # 统计的数据 # 加载action_mean, action_std, qpos_mean, qpos_std 14维 - with open(stats_path, 'rb') as f: - stats = pickle.load(f) - - # 数据预处理和后处理函数定义 - pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std'] - post_process = lambda a: a * stats['action_std'] + stats['action_mean'] - - max_publish_step = config['episode_len'] - chunk_size = config['policy_config']['chunk_size'] - - # 发布基础的姿态 - left0 = [-0.00133514404296875, 0.00209808349609375, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, 3.557830810546875] - right0 = [-0.00133514404296875, 0.00438690185546875, 0.034523963928222656, -0.053597450256347656, -0.00476837158203125, -0.00209808349609375, 3.557830810546875] - left1 = [-0.00133514404296875, 
0.00209808349609375, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, -0.3393220901489258] - right1 = [-0.00133514404296875, 0.00247955322265625, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, -0.3397035598754883] - - ros_operator.puppet_arm_publish_continuous(left0, right0) - input("Enter any key to continue :") - ros_operator.puppet_arm_publish_continuous(left1, right1) - action = None - # 推理 - with torch.inference_mode(): - while True and not rospy.is_shutdown(): - # 每个回合的步数 - t = 0 - max_t = 0 - rate = rospy.Rate(args.publish_rate) - if config['temporal_agg']: - all_time_actions = np.zeros([max_publish_step, max_publish_step + chunk_size, config['state_dim']]) - while t < max_publish_step and not rospy.is_shutdown(): - # start_time = time.time() - # query policy - if config['policy_class'] == "ACT": - if t >= max_t: - pre_action = action - inference_thread = threading.Thread(target=inference_process, - args=(args, config, ros_operator, - policy, stats, t, pre_action)) - inference_thread.start() - inference_thread.join() - inference_lock.acquire() - if inference_actions is not None: - inference_thread = None - all_actions = inference_actions - inference_actions = None - max_t = t + args.pos_lookahead_step - if config['temporal_agg']: - all_time_actions[[t], t:t + chunk_size] = all_actions - inference_lock.release() - if config['temporal_agg']: - actions_for_curr_step = all_time_actions[:, t] - actions_populated = np.all(actions_for_curr_step != 0, axis=1) - actions_for_curr_step = actions_for_curr_step[actions_populated] - k = 0.01 - exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step))) - exp_weights = exp_weights / exp_weights.sum() - exp_weights = exp_weights[:, np.newaxis] - raw_action = (actions_for_curr_step * exp_weights).sum(axis=0, keepdims=True) - else: - if args.pos_lookahead_step != 0: - raw_action = all_actions[:, t % args.pos_lookahead_step] - else: - raw_action = all_actions[:, t % chunk_size] - else: - raise NotImplementedError - action = post_process(raw_action[0]) - left_action = action[:7] # 取7维度 - right_action = action[7:14] - ros_operator.puppet_arm_publish(left_action, right_action) # puppet_arm_publish_continuous_thread - if args.use_robot_base: - vel_action = action[14:16] - ros_operator.robot_base_publish(vel_action) - t += 1 - # end_time = time.time() - # print("publish: ", t) - # print("time:", end_time - start_time) - # print("left_action:", left_action) - # print("right_action:", right_action) - rate.sleep() - - -class RosOperator: - def __init__(self, args): - self.robot_base_deque = None - self.puppet_arm_right_deque = None - self.puppet_arm_left_deque = None - self.img_front_deque = None - self.img_right_deque = None - self.img_left_deque = None - self.img_front_depth_deque = None - self.img_right_depth_deque = None - self.img_left_depth_deque = None - self.bridge = None - self.puppet_arm_left_publisher = None - self.puppet_arm_right_publisher = None - self.robot_base_publisher = None - self.puppet_arm_publish_thread = None - self.puppet_arm_publish_lock = None - self.args = args - self.ctrl_state = False - self.ctrl_state_lock = threading.Lock() - self.init() - self.init_ros() - - def init(self): - self.bridge = CvBridge() - self.img_left_deque = deque() - self.img_right_deque = deque() - self.img_front_deque = deque() - self.img_left_depth_deque = deque() - self.img_right_depth_deque = deque() - self.img_front_depth_deque = deque() - self.puppet_arm_left_deque = 
deque() - self.puppet_arm_right_deque = deque() - self.robot_base_deque = deque() - self.puppet_arm_publish_lock = threading.Lock() - self.puppet_arm_publish_lock.acquire() - - def puppet_arm_publish(self, left, right): - joint_state_msg = JointState() - joint_state_msg.header = Header() - joint_state_msg.header.stamp = rospy.Time.now() # 设置时间戳 - joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6'] # 设置关节名称 - joint_state_msg.position = left - self.puppet_arm_left_publisher.publish(joint_state_msg) - joint_state_msg.position = right - self.puppet_arm_right_publisher.publish(joint_state_msg) - - def robot_base_publish(self, vel): - vel_msg = Twist() - vel_msg.linear.x = vel[0] - vel_msg.linear.y = 0 - vel_msg.linear.z = 0 - vel_msg.angular.x = 0 - vel_msg.angular.y = 0 - vel_msg.angular.z = vel[1] - self.robot_base_publisher.publish(vel_msg) - - def puppet_arm_publish_continuous(self, left, right): - rate = rospy.Rate(self.args.publish_rate) - left_arm = None - right_arm = None - while True and not rospy.is_shutdown(): - if len(self.puppet_arm_left_deque) != 0: - left_arm = list(self.puppet_arm_left_deque[-1].position) - if len(self.puppet_arm_right_deque) != 0: - right_arm = list(self.puppet_arm_right_deque[-1].position) - if left_arm is None or right_arm is None: - rate.sleep() - continue - else: - break - left_symbol = [1 if left[i] - left_arm[i] > 0 else -1 for i in range(len(left))] - right_symbol = [1 if right[i] - right_arm[i] > 0 else -1 for i in range(len(right))] - flag = True - step = 0 - while flag and not rospy.is_shutdown(): - if self.puppet_arm_publish_lock.acquire(False): - return - left_diff = [abs(left[i] - left_arm[i]) for i in range(len(left))] - right_diff = [abs(right[i] - right_arm[i]) for i in range(len(right))] - flag = False - for i in range(len(left)): - if left_diff[i] < self.args.arm_steps_length[i]: - left_arm[i] = left[i] - else: - left_arm[i] += left_symbol[i] * self.args.arm_steps_length[i] - flag = True - for i in range(len(right)): - if right_diff[i] < self.args.arm_steps_length[i]: - right_arm[i] = right[i] - else: - right_arm[i] += right_symbol[i] * self.args.arm_steps_length[i] - flag = True - joint_state_msg = JointState() - joint_state_msg.header = Header() - joint_state_msg.header.stamp = rospy.Time.now() # 设置时间戳 - joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6'] # 设置关节名称 - joint_state_msg.position = left_arm - self.puppet_arm_left_publisher.publish(joint_state_msg) - joint_state_msg.position = right_arm - self.puppet_arm_right_publisher.publish(joint_state_msg) - step += 1 - print("puppet_arm_publish_continuous:", step) - rate.sleep() - - def puppet_arm_publish_linear(self, left, right): - num_step = 100 - rate = rospy.Rate(200) - - left_arm = None - right_arm = None - - while True and not rospy.is_shutdown(): - if len(self.puppet_arm_left_deque) != 0: - left_arm = list(self.puppet_arm_left_deque[-1].position) - if len(self.puppet_arm_right_deque) != 0: - right_arm = list(self.puppet_arm_right_deque[-1].position) - if left_arm is None or right_arm is None: - rate.sleep() - continue - else: - break - - traj_left_list = np.linspace(left_arm, left, num_step) - traj_right_list = np.linspace(right_arm, right, num_step) - - for i in range(len(traj_left_list)): - traj_left = traj_left_list[i] - traj_right = traj_right_list[i] - traj_left[-1] = left[-1] - traj_right[-1] = right[-1] - joint_state_msg = JointState() - joint_state_msg.header = Header() - 
joint_state_msg.header.stamp = rospy.Time.now() # 设置时间戳 - joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6'] # 设置关节名称 - joint_state_msg.position = traj_left - self.puppet_arm_left_publisher.publish(joint_state_msg) - joint_state_msg.position = traj_right - self.puppet_arm_right_publisher.publish(joint_state_msg) - rate.sleep() - - def puppet_arm_publish_continuous_thread(self, left, right): - if self.puppet_arm_publish_thread is not None: - self.puppet_arm_publish_lock.release() - self.puppet_arm_publish_thread.join() - self.puppet_arm_publish_lock.acquire(False) - self.puppet_arm_publish_thread = None - self.puppet_arm_publish_thread = threading.Thread(target=self.puppet_arm_publish_continuous, args=(left, right)) - self.puppet_arm_publish_thread.start() - - def get_frame(self): - if len(self.img_left_deque) == 0 or len(self.img_right_deque) == 0 or len(self.img_front_deque) == 0 or \ - (self.args.use_depth_image and (len(self.img_left_depth_deque) == 0 or len(self.img_right_depth_deque) == 0 or len(self.img_front_depth_deque) == 0)): - return False - if self.args.use_depth_image: - frame_time = min([self.img_left_deque[-1].header.stamp.to_sec(), self.img_right_deque[-1].header.stamp.to_sec(), self.img_front_deque[-1].header.stamp.to_sec(), - self.img_left_depth_deque[-1].header.stamp.to_sec(), self.img_right_depth_deque[-1].header.stamp.to_sec(), self.img_front_depth_deque[-1].header.stamp.to_sec()]) - else: - frame_time = min([self.img_left_deque[-1].header.stamp.to_sec(), self.img_right_deque[-1].header.stamp.to_sec(), self.img_front_deque[-1].header.stamp.to_sec()]) - - if len(self.img_left_deque) == 0 or self.img_left_deque[-1].header.stamp.to_sec() < frame_time: - return False - if len(self.img_right_deque) == 0 or self.img_right_deque[-1].header.stamp.to_sec() < frame_time: - return False - if len(self.img_front_deque) == 0 or self.img_front_deque[-1].header.stamp.to_sec() < frame_time: - return False - if len(self.puppet_arm_left_deque) == 0 or self.puppet_arm_left_deque[-1].header.stamp.to_sec() < frame_time: - return False - if len(self.puppet_arm_right_deque) == 0 or self.puppet_arm_right_deque[-1].header.stamp.to_sec() < frame_time: - return False - if self.args.use_depth_image and (len(self.img_left_depth_deque) == 0 or self.img_left_depth_deque[-1].header.stamp.to_sec() < frame_time): - return False - if self.args.use_depth_image and (len(self.img_right_depth_deque) == 0 or self.img_right_depth_deque[-1].header.stamp.to_sec() < frame_time): - return False - if self.args.use_depth_image and (len(self.img_front_depth_deque) == 0 or self.img_front_depth_deque[-1].header.stamp.to_sec() < frame_time): - return False - if self.args.use_robot_base and (len(self.robot_base_deque) == 0 or self.robot_base_deque[-1].header.stamp.to_sec() < frame_time): - return False - - while self.img_left_deque[0].header.stamp.to_sec() < frame_time: - self.img_left_deque.popleft() - img_left = self.bridge.imgmsg_to_cv2(self.img_left_deque.popleft(), 'passthrough') - - while self.img_right_deque[0].header.stamp.to_sec() < frame_time: - self.img_right_deque.popleft() - img_right = self.bridge.imgmsg_to_cv2(self.img_right_deque.popleft(), 'passthrough') - - while self.img_front_deque[0].header.stamp.to_sec() < frame_time: - self.img_front_deque.popleft() - img_front = self.bridge.imgmsg_to_cv2(self.img_front_deque.popleft(), 'passthrough') - - while self.puppet_arm_left_deque[0].header.stamp.to_sec() < frame_time: - self.puppet_arm_left_deque.popleft() - 
puppet_arm_left = self.puppet_arm_left_deque.popleft() - - while self.puppet_arm_right_deque[0].header.stamp.to_sec() < frame_time: - self.puppet_arm_right_deque.popleft() - puppet_arm_right = self.puppet_arm_right_deque.popleft() - - img_left_depth = None - if self.args.use_depth_image: - while self.img_left_depth_deque[0].header.stamp.to_sec() < frame_time: - self.img_left_depth_deque.popleft() - img_left_depth = self.bridge.imgmsg_to_cv2(self.img_left_depth_deque.popleft(), 'passthrough') - - img_right_depth = None - if self.args.use_depth_image: - while self.img_right_depth_deque[0].header.stamp.to_sec() < frame_time: - self.img_right_depth_deque.popleft() - img_right_depth = self.bridge.imgmsg_to_cv2(self.img_right_depth_deque.popleft(), 'passthrough') - - img_front_depth = None - if self.args.use_depth_image: - while self.img_front_depth_deque[0].header.stamp.to_sec() < frame_time: - self.img_front_depth_deque.popleft() - img_front_depth = self.bridge.imgmsg_to_cv2(self.img_front_depth_deque.popleft(), 'passthrough') - - robot_base = None - if self.args.use_robot_base: - while self.robot_base_deque[0].header.stamp.to_sec() < frame_time: - self.robot_base_deque.popleft() - robot_base = self.robot_base_deque.popleft() - - return (img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth, - puppet_arm_left, puppet_arm_right, robot_base) - - def img_left_callback(self, msg): - if len(self.img_left_deque) >= 2000: - self.img_left_deque.popleft() - self.img_left_deque.append(msg) - - def img_right_callback(self, msg): - if len(self.img_right_deque) >= 2000: - self.img_right_deque.popleft() - self.img_right_deque.append(msg) - - def img_front_callback(self, msg): - if len(self.img_front_deque) >= 2000: - self.img_front_deque.popleft() - self.img_front_deque.append(msg) - - def img_left_depth_callback(self, msg): - if len(self.img_left_depth_deque) >= 2000: - self.img_left_depth_deque.popleft() - self.img_left_depth_deque.append(msg) - - def img_right_depth_callback(self, msg): - if len(self.img_right_depth_deque) >= 2000: - self.img_right_depth_deque.popleft() - self.img_right_depth_deque.append(msg) - - def img_front_depth_callback(self, msg): - if len(self.img_front_depth_deque) >= 2000: - self.img_front_depth_deque.popleft() - self.img_front_depth_deque.append(msg) - - def puppet_arm_left_callback(self, msg): - if len(self.puppet_arm_left_deque) >= 2000: - self.puppet_arm_left_deque.popleft() - self.puppet_arm_left_deque.append(msg) - - def puppet_arm_right_callback(self, msg): - if len(self.puppet_arm_right_deque) >= 2000: - self.puppet_arm_right_deque.popleft() - self.puppet_arm_right_deque.append(msg) - - def robot_base_callback(self, msg): - if len(self.robot_base_deque) >= 2000: - self.robot_base_deque.popleft() - self.robot_base_deque.append(msg) - - def ctrl_callback(self, msg): - self.ctrl_state_lock.acquire() - self.ctrl_state = msg.data - self.ctrl_state_lock.release() - - def get_ctrl_state(self): - self.ctrl_state_lock.acquire() - state = self.ctrl_state - self.ctrl_state_lock.release() - return state - - def init_ros(self): - rospy.init_node('joint_state_publisher', anonymous=True) - rospy.Subscriber(self.args.img_left_topic, Image, self.img_left_callback, queue_size=1000, tcp_nodelay=True) - rospy.Subscriber(self.args.img_right_topic, Image, self.img_right_callback, queue_size=1000, tcp_nodelay=True) - rospy.Subscriber(self.args.img_front_topic, Image, self.img_front_callback, queue_size=1000, tcp_nodelay=True) - if self.args.use_depth_image: - 
rospy.Subscriber(self.args.img_left_depth_topic, Image, self.img_left_depth_callback, queue_size=1000, tcp_nodelay=True) - rospy.Subscriber(self.args.img_right_depth_topic, Image, self.img_right_depth_callback, queue_size=1000, tcp_nodelay=True) - rospy.Subscriber(self.args.img_front_depth_topic, Image, self.img_front_depth_callback, queue_size=1000, tcp_nodelay=True) - rospy.Subscriber(self.args.puppet_arm_left_topic, JointState, self.puppet_arm_left_callback, queue_size=1000, tcp_nodelay=True) - rospy.Subscriber(self.args.puppet_arm_right_topic, JointState, self.puppet_arm_right_callback, queue_size=1000, tcp_nodelay=True) - rospy.Subscriber(self.args.robot_base_topic, Odometry, self.robot_base_callback, queue_size=1000, tcp_nodelay=True) - self.puppet_arm_left_publisher = rospy.Publisher(self.args.puppet_arm_left_cmd_topic, JointState, queue_size=10) - self.puppet_arm_right_publisher = rospy.Publisher(self.args.puppet_arm_right_cmd_topic, JointState, queue_size=10) - self.robot_base_publisher = rospy.Publisher(self.args.robot_base_cmd_topic, Twist, queue_size=10) - - -def get_arguments(): - parser = argparse.ArgumentParser() - parser.add_argument('--ckpt_dir', action='store', type=str, help='ckpt_dir', required=True) - parser.add_argument('--task_name', action='store', type=str, help='task_name', default='aloha_mobile_dummy', required=False) - parser.add_argument('--max_publish_step', action='store', type=int, help='max_publish_step', default=10000, required=False) - parser.add_argument('--ckpt_name', action='store', type=str, help='ckpt_name', default='policy_best.ckpt', required=False) - parser.add_argument('--ckpt_stats_name', action='store', type=str, help='ckpt_stats_name', default='dataset_stats.pkl', required=False) - parser.add_argument('--policy_class', action='store', type=str, help='policy_class, capitalize', default='ACT', required=False) - parser.add_argument('--batch_size', action='store', type=int, help='batch_size', default=8, required=False) - parser.add_argument('--seed', action='store', type=int, help='seed', default=0, required=False) - parser.add_argument('--num_epochs', action='store', type=int, help='num_epochs', default=2000, required=False) - parser.add_argument('--lr', action='store', type=float, help='lr', default=1e-5, required=False) - parser.add_argument('--weight_decay', type=float, help='weight_decay', default=1e-4, required=False) - parser.add_argument('--dilation', action='store_true', - help="If true, we replace stride with dilation in the last convolutional block (DC5)", required=False) - parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'), - help="Type of positional embedding to use on top of the image features", required=False) - parser.add_argument('--masks', action='store_true', - help="Train segmentation head if the flag is provided") - parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', default=10, required=False) - parser.add_argument('--hidden_dim', action='store', type=int, help='hidden_dim', default=512, required=False) - parser.add_argument('--dim_feedforward', action='store', type=int, help='dim_feedforward', default=3200, required=False) - parser.add_argument('--temporal_agg', action='store', type=bool, help='temporal_agg', default=True, required=False) - - parser.add_argument('--state_dim', action='store', type=int, help='state_dim', default=14, required=False) - parser.add_argument('--lr_backbone', action='store', type=float, help='lr_backbone', default=1e-5, 
required=False) - parser.add_argument('--backbone', action='store', type=str, help='backbone', default='resnet18', required=False) - parser.add_argument('--loss_function', action='store', type=str, help='loss_function l1 l2 l1+l2', default='l1', required=False) - parser.add_argument('--enc_layers', action='store', type=int, help='enc_layers', default=4, required=False) - parser.add_argument('--dec_layers', action='store', type=int, help='dec_layers', default=7, required=False) - parser.add_argument('--nheads', action='store', type=int, help='nheads', default=8, required=False) - parser.add_argument('--dropout', default=0.1, type=float, help="Dropout applied in the transformer", required=False) - parser.add_argument('--pre_norm', action='store_true', required=False) - - parser.add_argument('--img_front_topic', action='store', type=str, help='img_front_topic', - default='/camera_f/color/image_raw', required=False) - parser.add_argument('--img_left_topic', action='store', type=str, help='img_left_topic', - default='/camera_l/color/image_raw', required=False) - parser.add_argument('--img_right_topic', action='store', type=str, help='img_right_topic', - default='/camera_r/color/image_raw', required=False) - - parser.add_argument('--img_front_depth_topic', action='store', type=str, help='img_front_depth_topic', - default='/camera_f/depth/image_raw', required=False) - parser.add_argument('--img_left_depth_topic', action='store', type=str, help='img_left_depth_topic', - default='/camera_l/depth/image_raw', required=False) - parser.add_argument('--img_right_depth_topic', action='store', type=str, help='img_right_depth_topic', - default='/camera_r/depth/image_raw', required=False) - - parser.add_argument('--puppet_arm_left_cmd_topic', action='store', type=str, help='puppet_arm_left_cmd_topic', - default='/master/joint_left', required=False) - parser.add_argument('--puppet_arm_right_cmd_topic', action='store', type=str, help='puppet_arm_right_cmd_topic', - default='/master/joint_right', required=False) - parser.add_argument('--puppet_arm_left_topic', action='store', type=str, help='puppet_arm_left_topic', - default='/puppet/joint_left', required=False) - parser.add_argument('--puppet_arm_right_topic', action='store', type=str, help='puppet_arm_right_topic', - default='/puppet/joint_right', required=False) - - parser.add_argument('--robot_base_topic', action='store', type=str, help='robot_base_topic', - default='/odom_raw', required=False) - parser.add_argument('--robot_base_cmd_topic', action='store', type=str, help='robot_base_topic', - default='/cmd_vel', required=False) - parser.add_argument('--use_robot_base', action='store', type=bool, help='use_robot_base', - default=False, required=False) - parser.add_argument('--publish_rate', action='store', type=int, help='publish_rate', - default=40, required=False) - parser.add_argument('--pos_lookahead_step', action='store', type=int, help='pos_lookahead_step', - default=0, required=False) - parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size', - default=32, required=False) - parser.add_argument('--arm_steps_length', action='store', type=float, help='arm_steps_length', - default=[0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.2], required=False) - - parser.add_argument('--use_actions_interpolation', action='store', type=bool, help='use_actions_interpolation', - default=False, required=False) - parser.add_argument('--use_depth_image', action='store', type=bool, help='use_depth_image', - default=False, required=False) - - # for Diffusion 
- parser.add_argument('--observation_horizon', action='store', type=int, help='observation_horizon', default=1, required=False) - parser.add_argument('--action_horizon', action='store', type=int, help='action_horizon', default=8, required=False) - parser.add_argument('--num_inference_timesteps', action='store', type=int, help='num_inference_timesteps', default=10, required=False) - parser.add_argument('--ema_power', action='store', type=int, help='ema_power', default=0.75, required=False) - args = parser.parse_args() - return args - - -def main(): - args = get_arguments() - ros_operator = RosOperator(args) - config = get_model_config(args) - model_inference(args, config, ros_operator, save_episode=True) - - -if __name__ == '__main__': - main() -# python act/inference.py --ckpt_dir ~/train0314/ \ No newline at end of file diff --git a/collect_data/read_parquet.py b/collect_data/read_parquet.py deleted file mode 100644 index 577a1e3..0000000 --- a/collect_data/read_parquet.py +++ /dev/null @@ -1,33 +0,0 @@ -import pandas as pd - -def read_and_print_parquet_row(file_path, row_index=0): - """ - 读取Parquet文件并打印指定行的数据 - - 参数: - file_path (str): Parquet文件路径 - row_index (int): 要打印的行索引(默认为第0行) - """ - try: - # 读取Parquet文件 - df = pd.read_parquet(file_path) - - # 检查行索引是否有效 - if row_index >= len(df): - print(f"错误: 行索引 {row_index} 超出范围(文件共有 {len(df)} 行)") - return - - # 打印指定行数据 - print(f"文件: {file_path}") - print(f"第 {row_index} 行数据:\n{'-'*30}") - print(df.iloc[row_index]) - - except FileNotFoundError: - print(f"错误: 文件 {file_path} 不存在") - except Exception as e: - print(f"读取失败: {str(e)}") - -# 示例用法 -if __name__ == "__main__": - file_path = "example.parquet" # 替换为你的Parquet文件路径 - read_and_print_parquet_row("/home/jgl20/LYT/work/data/data/chunk-000/episode_000000.parquet", row_index=0) # 打印第0行 diff --git a/collect_data/replay_data.py b/collect_data/replay_data.py old mode 100644 new mode 100755 index 6c880dc..86ce344 --- a/collect_data/replay_data.py +++ b/collect_data/replay_data.py @@ -10,24 +10,63 @@ from cv_bridge import CvBridge from std_msgs.msg import Header from sensor_msgs.msg import Image, JointState from geometry_msgs.msg import Twist -from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +import sys +sys.path.append("./") +def load_hdf5(dataset_dir, dataset_name): + dataset_path = os.path.join(dataset_dir, dataset_name + '.hdf5') + if not os.path.isfile(dataset_path): + print(f'Dataset does not exist at \n{dataset_path}\n') + exit() + with h5py.File(dataset_path, 'r') as root: + is_sim = root.attrs['sim'] + compressed = root.attrs.get('compress', False) + qpos = root['/observations/qpos'][()] + qvel = root['/observations/qvel'][()] + if 'effort' in root.keys(): + effort = root['/observations/effort'][()] + else: + effort = None + action = root['/action'][()] + base_action = root['/base_action'][()] + + image_dict = dict() + for cam_name in root[f'/observations/images/'].keys(): + image_dict[cam_name] = root[f'/observations/images/{cam_name}'][()] + + if compressed: + compress_len = root['/compress_len'][()] + + if compressed: + for cam_id, cam_name in enumerate(image_dict.keys()): + # un-pad and uncompress + padded_compressed_image_list = image_dict[cam_name] + image_list = [] + for frame_id, padded_compressed_image in enumerate(padded_compressed_image_list): # [:1000] to save memory + image_len = int(compress_len[cam_id, frame_id]) + + compressed_image = padded_compressed_image + image = cv2.imdecode(compressed_image, 1) + image_list.append(image) + image_dict[cam_name] = 
image_list + + return qpos, qvel, effort, action, base_action, image_dict def main(args): rospy.init_node("replay_node") bridge = CvBridge() - # img_left_publisher = rospy.Publisher(args.img_left_topic, Image, queue_size=10) - # img_right_publisher = rospy.Publisher(args.img_right_topic, Image, queue_size=10) - # img_front_publisher = rospy.Publisher(args.img_front_topic, Image, queue_size=10) + img_left_publisher = rospy.Publisher(args.img_left_topic, Image, queue_size=10) + img_right_publisher = rospy.Publisher(args.img_right_topic, Image, queue_size=10) + img_front_publisher = rospy.Publisher(args.img_front_topic, Image, queue_size=10) - # puppet_arm_left_publisher = rospy.Publisher(args.puppet_arm_left_topic, JointState, queue_size=10) - # puppet_arm_right_publisher = rospy.Publisher(args.puppet_arm_right_topic, JointState, queue_size=10) + puppet_arm_left_publisher = rospy.Publisher(args.puppet_arm_left_topic, JointState, queue_size=10) + puppet_arm_right_publisher = rospy.Publisher(args.puppet_arm_right_topic, JointState, queue_size=10) master_arm_left_publisher = rospy.Publisher(args.master_arm_left_topic, JointState, queue_size=10) master_arm_right_publisher = rospy.Publisher(args.master_arm_right_topic, JointState, queue_size=10) - # robot_base_publisher = rospy.Publisher(args.robot_base_topic, Twist, queue_size=10) + robot_base_publisher = rospy.Publisher(args.robot_base_topic, Twist, queue_size=10) # dataset_dir = args.dataset_dir @@ -35,78 +74,130 @@ def main(args): # task_name = args.task_name # dataset_name = f'episode_{episode_idx}' - dataset = LeRobotDataset(args.repo_id, root=args.root, episodes=[args.episode]) - actions = dataset.hf_dataset.select_columns("action") - velocitys = dataset.hf_dataset.select_columns("observation.velocity") - efforts = dataset.hf_dataset.select_columns("observation.effort") - origin_left = [-0.0057,-0.031, -0.0122, -0.032, 0.0099, 0.0179, 0.2279] origin_right = [ 0.0616, 0.0021, 0.0475, -0.1013, 0.1097, 0.0872, 0.2279] + joint_state_msg = JointState() joint_state_msg.header = Header() - joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', ''] # 设置关节名称 + joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6'] # 设置关节名称 twist_msg = Twist() - - rate = rospy.Rate(args.fps) + + rate = rospy.Rate(args.frame_rate) # qposs, qvels, efforts, actions, base_actions, image_dicts = load_hdf5(os.path.join(dataset_dir, task_name), dataset_name) + actions = [[ 0.1277, 0.0167, 0.0142, -0.1628, 0.1473, -0.0296, 0.3238, -0.1094, + 0.0109, 0.0139, -0.1591, -0.1490, -0.1650, -0.0980]] - last_action = [-0.00019073486328125, 0.00934600830078125, 0.01354217529296875, -0.01049041748046875, -0.00057220458984375, -0.00057220458984375, -0.00526118278503418, -0.00095367431640625, 0.00705718994140625, 0.01239776611328125, -0.00705718994140625, -0.00019073486328125, -0.00057220458984375, -0.009171326644718647] - last_velocity = [-0.010990142822265625, -0.010990142822265625, -0.03296661376953125, 0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.03296661376953125] - last_effort = [-0.021978378295898438, 0.2417583465576172, 4.320878982543945, 3.6527481079101562, -0.013187408447265625, -0.013187408447265625, 0.0, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, 
-0.03296661376953125, -0.03296661376953125] - rate = rospy.Rate(50) - for idx in range(len(actions)): - action = actions[idx]['action'].detach().cpu().numpy() - velocity = velocitys[idx]['observation.velocity'].detach().cpu().numpy() - effort = efforts[idx]['observation.effort'].detach().cpu().numpy() - if(rospy.is_shutdown()): - break - - new_actions = np.linspace(last_action, action, 5) # 插值 - new_velocitys = np.linspace(last_velocity, velocity, 5) # 插值 - new_efforts = np.linspace(last_effort, effort, 5) # 插值 - last_action = action - last_velocity = velocity - last_effort = effort - for act in new_actions: - print(np.round(act[:7], 4)) - cur_timestamp = rospy.Time.now() # 设置时间戳 - joint_state_msg.header.stamp = cur_timestamp + if not args.only_pub_master: + last_action = [-0.0057,-0.031, -0.0122, -0.032, 0.0099, 0.0179, 0.2279, 0.0616, 0.0021, 0.0475, -0.1013, 0.1097, 0.0872, 0.2279] + rate = rospy.Rate(100) + for action in actions: + if(rospy.is_shutdown()): + break - joint_state_msg.position = act[:7] - joint_state_msg.velocity = last_velocity[:7] - joint_state_msg.effort = last_effort[:7] + new_actions = np.linspace(last_action, action, 50) # 插值 + last_action = action + for act in new_actions: + print(np.round(act[:7], 4)) + cur_timestamp = rospy.Time.now() # 设置时间戳 + joint_state_msg.header.stamp = cur_timestamp + + joint_state_msg.position = act[:7] + master_arm_left_publisher.publish(joint_state_msg) + + joint_state_msg.position = act[7:] + master_arm_right_publisher.publish(joint_state_msg) + + if(rospy.is_shutdown()): + break + rate.sleep() + + else: + i = 0 + while(not rospy.is_shutdown() and i < len(actions)): + print("left: ", np.round(qposs[i][:7], 4), " right: ", np.round(qposs[i][7:], 4)) + + cam_names = [k for k in image_dicts.keys()] + image0 = image_dicts[cam_names[0]][i] + image0 = image0[:, :, [2, 1, 0]] # swap B and R channel + + image1 = image_dicts[cam_names[1]][i] + image1 = image1[:, :, [2, 1, 0]] # swap B and R channel + + image2 = image_dicts[cam_names[2]][i] + image2 = image2[:, :, [2, 1, 0]] # swap B and R channel + + cur_timestamp = rospy.Time.now() # 设置时间戳 + + joint_state_msg.header.stamp = cur_timestamp + joint_state_msg.position = actions[i][:7] master_arm_left_publisher.publish(joint_state_msg) - joint_state_msg.position = act[7:] - joint_state_msg.velocity = last_velocity[:7] - joint_state_msg.effort = last_effort[7:] + joint_state_msg.position = actions[i][7:] master_arm_right_publisher.publish(joint_state_msg) - if(rospy.is_shutdown()): - break - rate.sleep() + joint_state_msg.position = qposs[i][:7] + puppet_arm_left_publisher.publish(joint_state_msg) + + joint_state_msg.position = qposs[i][7:] + puppet_arm_right_publisher.publish(joint_state_msg) + + img_front_publisher.publish(bridge.cv2_to_imgmsg(image0, "bgr8")) + img_left_publisher.publish(bridge.cv2_to_imgmsg(image1, "bgr8")) + img_right_publisher.publish(bridge.cv2_to_imgmsg(image2, "bgr8")) + twist_msg.linear.x = base_actions[i][0] + twist_msg.angular.z = base_actions[i][1] + robot_base_publisher.publish(twist_msg) + + i += 1 + rate.sleep() + if __name__ == '__main__': parser = argparse.ArgumentParser() - # parser.add_argument('--master_arm_left_topic', action='store', type=str, help='master_arm_left_topic', - # default='/master/joint_left', required=False) - # parser.add_argument('--master_arm_right_topic', action='store', type=str, help='master_arm_right_topic', - # default='/master/joint_right', required=False) - - - args = parser.parse_args() - args.repo_id = "tangger/test" - args.root = 
"/home/ubuntu/LYT/aloha_lerobot/data1" - args.episode = 1 # replay episode - args.master_arm_left_topic = "/master/joint_left" - args.master_arm_right_topic = "/master/joint_right" - args.fps = 30 + parser.add_argument('--dataset_dir', action='store', type=str, help='Dataset dir.', required=False) + parser.add_argument('--task_name', action='store', type=str, help='Task name.', + default="aloha_mobile_dummy", required=False) + parser.add_argument('--episode_idx', action='store', type=int, help='Episode index.',default=0, required=False) + + parser.add_argument('--camera_names', action='store', type=str, help='camera_names', + default=['cam_high', 'cam_left_wrist', 'cam_right_wrist'], required=False) + + parser.add_argument('--img_front_topic', action='store', type=str, help='img_front_topic', + default='/camera_f/color/image_raw', required=False) + parser.add_argument('--img_left_topic', action='store', type=str, help='img_left_topic', + default='/camera_l/color/image_raw', required=False) + parser.add_argument('--img_right_topic', action='store', type=str, help='img_right_topic', + default='/camera_r/color/image_raw', required=False) + + parser.add_argument('--master_arm_left_topic', action='store', type=str, help='master_arm_left_topic', + default='/master/joint_left', required=False) + parser.add_argument('--master_arm_right_topic', action='store', type=str, help='master_arm_right_topic', + default='/master/joint_right', required=False) + + parser.add_argument('--puppet_arm_left_topic', action='store', type=str, help='puppet_arm_left_topic', + default='/puppet/joint_left', required=False) + parser.add_argument('--puppet_arm_right_topic', action='store', type=str, help='puppet_arm_right_topic', + default='/puppet/joint_right', required=False) + + parser.add_argument('--robot_base_topic', action='store', type=str, help='robot_base_topic', + default='/cmd_vel', required=False) + parser.add_argument('--use_robot_base', action='store', type=bool, help='use_robot_base', + default=False, required=False) + + parser.add_argument('--frame_rate', action='store', type=int, help='frame_rate', + default=30, required=False) + + parser.add_argument('--only_pub_master', action='store_true', help='only_pub_master',required=False) + + + + args = parser.parse_args() main(args) # python collect_data.py --max_timesteps 500 --is_compress --episode_idx 0 \ No newline at end of file diff --git a/collect_data/requirements.txt b/collect_data/requirements.txt new file mode 100755 index 0000000..a04025d --- /dev/null +++ b/collect_data/requirements.txt @@ -0,0 +1,10 @@ +opencv-python==4.9.0.80 +matplotlib==3.7.5 +h5py==3.8.0 +dm-env==1.6 +numpy==1.23.4 +pyyaml +rospkg==1.5.0 +catkin-pkg==1.0.0 +empy==3.3.4 +PyQt5==5.15.10 diff --git a/collect_data/test.py b/collect_data/test.py deleted file mode 100644 index 8eb8748..0000000 --- a/collect_data/test.py +++ /dev/null @@ -1,70 +0,0 @@ -from lerobot.common.policies.act.modeling_act import ACTPolicy -from lerobot.common.robot_devices.utils import busy_wait -import time -import argparse -from agilex_robot import AgilexRobot -import torch - -def get_arguments(): - parser = argparse.ArgumentParser() - args = parser.parse_args() - args.fps = 30 - args.resume = False - args.repo_id = "tangger/test" - args.root = "./data2" - args.num_image_writer_processes = 0 - args.num_image_writer_threads_per_camera = 4 - args.video = True - args.num_episodes = 50 - args.episode_time_s = 30000 - args.play_sounds = False - args.display_cameras = True - args.single_task = "test test" - 
args.use_depth_image = False - args.use_base = False - args.push_to_hub = False - args.policy= None - args.teleoprate = False - return args - - -cfg = get_arguments() -robot = AgilexRobot(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg) -inference_time_s = 360 -fps = 30 -device = "cuda" # TODO: On Mac, use "mps" or "cpu" - -ckpt_path = "/home/ubuntu/LYT/lerobot/outputs/train/act_move_tube_on_scale/checkpoints/last/pretrained_model" -policy = ACTPolicy.from_pretrained(ckpt_path) -policy.to(device) - -for _ in range(inference_time_s * fps): - start_time = time.perf_counter() - - # Read the follower state and access the frames from the cameras - observation = robot.capture_observation() - if observation is None: - print("Observation is None, skipping...") - continue - - # Convert to pytorch format: channel first and float32 in [0,1] - # with batch dimension - for name in observation: - if "image" in name: - observation[name] = observation[name].type(torch.float32) / 255 - observation[name] = observation[name].permute(2, 0, 1).contiguous() - observation[name] = observation[name].unsqueeze(0) - observation[name] = observation[name].to(device) - - # Compute the next action with the policy - # based on the current observation - action = policy.select_action(observation) - # Remove batch dimension - action = action.squeeze(0) - # Move to cpu, if not already the case - action = action.to("cpu") - # Order the robot to move - robot.send_action(action) - - dt_s = time.perf_counter() - start_time - busy_wait(1 / fps - dt_s) \ No newline at end of file diff --git a/collect_data/utils.py b/collect_data/utils.py new file mode 100644 index 0000000..08db2f1 --- /dev/null +++ b/collect_data/utils.py @@ -0,0 +1,272 @@ +import cv2 +import numpy as np +import h5py +import time + + + +def display_camera_grid(image_dict, grid_shape=None, window_name="MindRobot-V1 Data Collection", scale=1.0): + """ + Display multiple camera views in one window (original aspect ratios preserved, with an optional global scale). + + Args: + image_dict: {camera name: image numpy array} + grid_shape: (rows, cols) layout; None computes it automatically + window_name: window title + scale: global display scale (0.5 shows everything at 50% of the original size) + """ + # Input validation and data preprocessing + if not isinstance(image_dict, dict): + raise TypeError("Input must be a dictionary") + + valid_data = [] + for name, img in image_dict.items(): + if not isinstance(img, np.ndarray): + continue + if img.dtype != np.uint8: + img = img.astype(np.uint8) + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + elif img.shape[2] == 4: + img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR) + elif img.shape[2] == 3: + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + valid_data.append((name, img)) + + if not valid_data: + print("Error: no valid images to display!") + return None + + # Automatically compute the grid layout + num_valid = len(valid_data) + if grid_shape is None: + grid_shape = (1, num_valid) if num_valid <= 3 else (2, int(np.ceil(num_valid/2))) + + rows, cols = grid_shape + + # Maximum height per row and width per column + row_heights = [0]*rows + col_widths = [0]*cols + + for i, (_, img) in enumerate(valid_data[:rows*cols]): + r, c = i//cols, i%cols + row_heights[r] = max(row_heights[r], img.shape[0]) + col_widths[c] = max(col_widths[c], img.shape[1]) + + # Total canvas size (with the global scale applied) + canvas_h = int(sum(row_heights) * scale) + canvas_w = int(sum(col_widths) * scale) + + # Create the canvas + canvas = np.zeros((canvas_h, canvas_w, 3), dtype=np.uint8) + + # Display region of each sub-view + row_pos = [0] + [int(sum(row_heights[:i+1])*scale) for i in range(rows)] + col_pos = [0] + [int(sum(col_widths[:i+1])*scale) for i in range(cols)] + + # Place the images + for i, (name, img) in enumerate(valid_data[:rows*cols]): + r, c = i//cols, i%cols + + # Display region of the current image
+ x1, x2 = col_pos[c], col_pos[c+1] + y1, y2 = row_pos[r], row_pos[r+1] + + # Scaled size of the current image + display_h = int(img.shape[0] * scale) + display_w = int(img.shape[1] * scale) + + # Resize the image (aspect ratio preserved) + resized_img = cv2.resize(img, (display_w, display_h)) + + # Place it on the canvas + canvas[y1:y1+display_h, x1:x1+display_w] = resized_img + + # Add the camera label (font scaled proportionally) + font_scale = 0.8 * scale + thickness = max(2, int(2 * scale)) + cv2.putText(canvas, name, (x1+10, y1+30), + cv2.FONT_HERSHEY_SIMPLEX, font_scale, (255,255,255), thickness) + + # Show the window (resizable to fit the screen) + cv2.namedWindow(window_name, cv2.WINDOW_NORMAL) + cv2.imshow(window_name, canvas) + cv2.resizeWindow(window_name, canvas_w, canvas_h) + cv2.waitKey(1) + + return canvas + + +# Save the collected episode to an HDF5 file +def save_data(args, timesteps, actions, dataset_path): + # Data dictionary + data_size = len(actions) + data_dict = { + # observations (qpos, qvel, effort) from the timesteps, plus the actions actually sent + '/observations/qpos': [], + '/observations/qvel': [], + '/observations/effort': [], + '/action': [], + '/base_action': [], + # '/base_action_t265': [], + } + + # Camera entries for the observed images + for cam_name in args.camera_names: + data_dict[f'/observations/images/{cam_name}'] = [] + if args.use_depth_image: + data_dict[f'/observations/images_depth/{cam_name}'] = [] + + # len(action): max_timesteps, len(time_steps): max_timesteps + 1 + # Iterate over the recorded actions + while actions: + # Pop one item from each queue per iteration + action = actions.pop(0) # current action + ts = timesteps.pop(0) # corresponding observation (previous frame) + + # Append the values to the dictionary + # qpos, qvel, effort returned by the timestep + data_dict['/observations/qpos'].append(ts.observation['qpos']) + data_dict['/observations/qvel'].append(ts.observation['qvel']) + data_dict['/observations/effort'].append(ts.observation['effort']) + + # The action that was actually published + data_dict['/action'].append(action) + data_dict['/base_action'].append(ts.observation['base_vel']) + + # Camera data + # data_dict['/base_action_t265'].append(ts.observation['base_vel_t265']) + for cam_name in args.camera_names: + data_dict[f'/observations/images/{cam_name}'].append(ts.observation['images'][cam_name]) + if args.use_depth_image: + data_dict[f'/observations/images_depth/{cam_name}'].append(ts.observation['images_depth'][cam_name]) + + t0 = time.time() + with h5py.File(dataset_path + '.hdf5', 'w', rdcc_nbytes=1024**2*2) as root: + # File attributes: + # 1 whether the data comes from simulation + # 2 whether the images are compressed + # + root.attrs['sim'] = False + root.attrs['compress'] = False + + # Create a new group 'observations' for the observed state + # Image group + obs = root.create_group('observations') + image = obs.create_group('images') + for cam_name in args.camera_names: + _ = image.create_dataset(cam_name, (data_size, 480, 640, 3), dtype='uint8', + chunks=(1, 480, 640, 3), ) + if args.use_depth_image: + image_depth = obs.create_group('images_depth') + for cam_name in args.camera_names: + _ = image_depth.create_dataset(cam_name, (data_size, 480, 640), dtype='uint16', + chunks=(1, 480, 640), ) + + _ = obs.create_dataset('qpos', (data_size, 14)) + _ = obs.create_dataset('qvel', (data_size, 14)) + _ = obs.create_dataset('effort', (data_size, 14)) + _ = root.create_dataset('action', (data_size, 14)) + _ = root.create_dataset('base_action', (data_size, 2)) + + # data_dict write into h5py.File + for name, array in data_dict.items(): + root[name][...] = array + print(f'\033[32m\nSaving: {time.time() - t0:.1f} secs. {dataset_path} \033[0m\n') + + +def is_headless(): + """ + Check if the environment is headless (no display available). + + Returns: + bool: True if the environment is headless, False otherwise.
+ """ + try: + import tkinter as tk + root = tk.Tk() + root.withdraw() + root.update() + root.destroy() + return False + except: + return True + + +def init_keyboard_listener(): + """ + Initialize keyboard listener for control events with new key mappings: + - Left arrow: Start data recording + - Right arrow: Save current data + - Down arrow: Discard current data + - Up arrow: Replay current data + - ESC: Early termination + + Returns: + tuple: (listener, events) - Keyboard listener and events dictionary + """ + events = { + "exit_early": False, + "record_start": False, + "save_data": False, + "discard_data": False, + "replay_data": False + } + + if is_headless(): + print( + "Headless environment detected. On-screen cameras display and keyboard inputs will not be available." + ) + return None, events + + # Only import pynput if not in a headless environment + from pynput import keyboard + + def on_press(key): + try: + if key == keyboard.Key.left: + print("← Left arrow: STARTING data recording...") + events.update({ + "record_start": True, + "exit_early": False, + "save_data": False, + "discard_data": False + }) + + elif key == keyboard.Key.right: + print("→ Right arrow: SAVING current data...") + events.update({ + "save_data": True, + "exit_early": False, + "record_start": False + }) + + elif key == keyboard.Key.down: + print("↓ Down arrow: DISCARDING current data...") + events.update({ + "discard_data": True, + "exit_early": False, + "record_start": False + }) + + elif key == keyboard.Key.up: + print("↑ Up arrow: REPLAYING current data...") + events.update({ + "replay_data": True, + "exit_early": False + }) + + elif key == keyboard.Key.esc: + print("ESC: EARLY TERMINATION requested") + events.update({ + "exit_early": True, + "record_start": False + }) + + except Exception as e: + print(f"Error handling key press: {e}") + + listener = keyboard.Listener(on_press=on_press) + listener.start() + + return listener, events diff --git a/collect_data/visualize_episodes.py b/collect_data/visualize_episodes.py new file mode 100755 index 0000000..828e430 --- /dev/null +++ b/collect_data/visualize_episodes.py @@ -0,0 +1,159 @@ +#coding=utf-8 +import os +import numpy as np +import cv2 +import h5py +import argparse +import matplotlib.pyplot as plt + +DT = 0.02 +# JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"] +JOINT_NAMES = ["joint0", "joint1", "joint2", "joint3", "joint4", "joint5"] +STATE_NAMES = JOINT_NAMES + ["gripper"] +BASE_STATE_NAMES = ["linear_vel", "angular_vel"] + +def load_hdf5(dataset_dir, dataset_name): + dataset_path = os.path.join(dataset_dir, dataset_name + '.hdf5') + if not os.path.isfile(dataset_path): + print(f'Dataset does not exist at \n{dataset_path}\n') + exit() + + with h5py.File(dataset_path, 'r') as root: + is_sim = root.attrs['sim'] + compressed = root.attrs.get('compress', False) + qpos = root['/observations/qpos'][()] + qvel = root['/observations/qvel'][()] + if 'effort' in root.keys(): + effort = root['/observations/effort'][()] + else: + effort = None + action = root['/action'][()] + base_action = root['/base_action'][()] + image_dict = dict() + for cam_name in root[f'/observations/images/'].keys(): + image_dict[cam_name] = root[f'/observations/images/{cam_name}'][()] + if compressed: + compress_len = root['/compress_len'][()] + + if compressed: + for cam_id, cam_name in enumerate(image_dict.keys()): + # un-pad and uncompress + padded_compressed_image_list = image_dict[cam_name] + image_list = [] + for frame_id, 
padded_compressed_image in enumerate(padded_compressed_image_list): # [:1000] to save memory + image_len = int(compress_len[cam_id, frame_id]) + compressed_image = padded_compressed_image + image = cv2.imdecode(compressed_image, 1) + image_list.append(image) + image_dict[cam_name] = image_list + + return qpos, qvel, effort, action, base_action, image_dict + +def main(args): + dataset_dir = args['dataset_dir'] + episode_idx = args['episode_idx'] + task_name = args['task_name'] + dataset_name = f'episode_{episode_idx}' + + qpos, qvel, effort, action, base_action, image_dict = load_hdf5(os.path.join(dataset_dir, task_name), dataset_name) + + print('hdf5 loaded!!') + + save_videos(image_dict, action, DT, video_path=os.path.join(dataset_dir, dataset_name + '_video.mp4')) + + + + visualize_joints(qpos, action, plot_path=os.path.join(dataset_dir, dataset_name + '_qpos.png')) + visualize_base(base_action, plot_path=os.path.join(dataset_dir, dataset_name + '_base_action.png')) + +def save_videos(video, actions, dt, video_path=None): + cam_names = list(video.keys()) + all_cam_videos = [] + for cam_name in cam_names: + all_cam_videos.append(video[cam_name]) + all_cam_videos = np.concatenate(all_cam_videos, axis=2) # width dimension + + n_frames, h, w, _ = all_cam_videos.shape + fps = int(1 / dt) + out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h)) + for t in range(n_frames): + image = all_cam_videos[t] + image = image[:, :, [2, 1, 0]] # swap B and R channel + cv2.imshow("images",image) + cv2.waitKey(30) + print("episode_id: ", t, "left: ", np.round(actions[t][:7], 3), "right: ", np.round(actions[t][7:], 3), "\n") + out.write(image) + out.release() + print(f'Saved video to: {video_path}') + + +def visualize_joints(qpos_list, command_list, plot_path=None, ylim=None, label_overwrite=None): + if label_overwrite: + label1, label2 = label_overwrite + else: + label1, label2 = 'State', 'Command' + + qpos = np.array(qpos_list) # ts, dim + command = np.array(command_list) + + num_ts, num_dim = qpos.shape + h, w = 2, num_dim + num_figs = num_dim + fig, axs = plt.subplots(num_figs, 1, figsize=(8, 2 * num_dim)) + + # plot joint state + all_names = [name + '_left' for name in STATE_NAMES] + [name + '_right' for name in STATE_NAMES] + for dim_idx in range(num_dim): + ax = axs[dim_idx] + ax.plot(qpos[:, dim_idx], label=label1, color='orangered') + ax.set_title(f'Joint {dim_idx}: {all_names[dim_idx]}') + ax.legend() + + # plot arm command + # for dim_idx in range(num_dim): + # ax = axs[dim_idx] + # ax.plot(command[:, dim_idx], label=label2) + # ax.legend() + + if ylim: + for dim_idx in range(num_dim): + ax = axs[dim_idx] + ax.set_ylim(ylim) + + plt.tight_layout() + plt.savefig(plot_path) + print(f'Saved qpos plot to: {plot_path}') + plt.close() + + +def visualize_base(readings, plot_path=None): + readings = np.array(readings) # ts, dim + num_ts, num_dim = readings.shape + num_figs = num_dim + fig, axs = plt.subplots(num_figs, 1, figsize=(8, 2 * num_dim)) + + # plot joint state + all_names = BASE_STATE_NAMES + for dim_idx in range(num_dim): + ax = axs[dim_idx] + ax.plot(readings[:, dim_idx], label='raw') + ax.plot(np.convolve(readings[:, dim_idx], np.ones(20)/20, mode='same'), label='smoothed_20') + ax.plot(np.convolve(readings[:, dim_idx], np.ones(10)/10, mode='same'), label='smoothed_10') + ax.plot(np.convolve(readings[:, dim_idx], np.ones(5)/5, mode='same'), label='smoothed_5') + ax.set_title(f'Joint {dim_idx}: {all_names[dim_idx]}') + ax.legend() + + + plt.tight_layout() + 
plt.savefig(plot_path) + print(f'Saved base action plot to: {plot_path}') + plt.close() + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--dataset_dir', action='store', type=str, help='Dataset dir.', required=True) + parser.add_argument('--task_name', action='store', type=str, help='Task name.', + default="aloha_mobile_dummy", required=False) + parser.add_argument('--episode_idx', action='store', type=int, help='Episode index.', default=0, required=False) + + main(vars(parser.parse_args()))