diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..8e255c0 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +cobot_magic/ +librealsense/ \ No newline at end of file diff --git a/collect_data/README.MD b/collect_data/README.MD new file mode 100644 index 0000000..9e4d14a --- /dev/null +++ b/collect_data/README.MD @@ -0,0 +1,3 @@ +python collect_data.py --robot.type=aloha --control.type=record --control.fps=30 --control.single_task="Grasp a lego block and put it in the bin." --control.repo_id=tangger/test --control.num_episodes=1 --control.root=./data + +python lerobot/scripts/train.py --dataset.repo_id=maic/move_tube_on_scale --policy.type=act --output_dir=outputs/train/act_move_tube_on_scale --job_name=act_move_tube_on_scale --policy.device=cuda --wandb.enable=true --dataset.root=/home/ubuntu/LYT/aloha_lerobot/data1 \ No newline at end of file diff --git a/collect_data/__pycache__/agilex_robot.cpython-310.pyc b/collect_data/__pycache__/agilex_robot.cpython-310.pyc new file mode 100644 index 0000000..6eb2734 Binary files /dev/null and b/collect_data/__pycache__/agilex_robot.cpython-310.pyc differ diff --git a/collect_data/__pycache__/ros_robot.cpython-310.pyc b/collect_data/__pycache__/ros_robot.cpython-310.pyc new file mode 100644 index 0000000..b95e9b2 Binary files /dev/null and b/collect_data/__pycache__/ros_robot.cpython-310.pyc differ diff --git a/collect_data/__pycache__/rosoperator.cpython-310.pyc b/collect_data/__pycache__/rosoperator.cpython-310.pyc new file mode 100644 index 0000000..7c51e60 Binary files /dev/null and b/collect_data/__pycache__/rosoperator.cpython-310.pyc differ diff --git a/collect_data/__pycache__/rosrobot.cpython-310.pyc b/collect_data/__pycache__/rosrobot.cpython-310.pyc new file mode 100644 index 0000000..17b14c6 Binary files /dev/null and b/collect_data/__pycache__/rosrobot.cpython-310.pyc differ diff --git a/collect_data/agilex.yaml b/collect_data/agilex.yaml new file mode 100644 index 0000000..5a9f04a --- /dev/null +++ b/collect_data/agilex.yaml @@ -0,0 +1,140 @@ +robot_type: aloha_agilex +ros_node_name: record_episodes +cameras: + cam_front: + img_topic_name: /camera_f/color/image_raw + depth_topic_name: /camera_f/depth/image_raw + width: 480 + height: 640 + rgb_shape: [480, 640, 3] + cam_left: + img_topic_name: /camera_l/color/image_raw + depth_topic_name: /camera_l/depth/image_raw + rgb_shape: [480, 640, 3] + width: 480 + height: 640 + cam_right: + img_topic_name: /camera_r/color/image_raw + depth_topic_name: /camera_r/depth/image_raw + rgb_shape: [480, 640, 3] + width: 480 + height: 640 + +arm: + master_left: + topic_name: /master/joint_left + motors: [ + "left_joint0", + "left_joint1", + "left_joint2", + "left_joint3", + "left_joint4", + "left_joint5", + "left_none" + ] + master_right: + topic_name: /master/joint_right + motors: [ + "right_joint0", + "right_joint1", + "right_joint2", + "right_joint3", + "right_joint4", + "right_joint5", + "right_none" + ] + puppet_left: + topic_name: /puppet/joint_left + motors: [ + "left_joint0", + "left_joint1", + "left_joint2", + "left_joint3", + "left_joint4", + "left_joint5", + "left_none" + ] + puppet_right: + topic_name: /puppet/joint_right + motors: [ + "right_joint0", + "right_joint1", + "right_joint2", + "right_joint3", + "right_joint4", + "right_joint5", + "right_none" + ] + +# follow the joint name in ros +state: + motors: [ + "left_joint0", + "left_joint1", + "left_joint2", + "left_joint3", + "left_joint4", + "left_joint5", + "left_none", + "right_joint0", + "right_joint1", + 
"right_joint2", + "right_joint3", + "right_joint4", + "right_joint5", + "right_none" + ] + +velocity: + motors: [ + "left_joint0", + "left_joint1", + "left_joint2", + "left_joint3", + "left_joint4", + "left_joint5", + "left_none", + "right_joint0", + "right_joint1", + "right_joint2", + "right_joint3", + "right_joint4", + "right_joint5", + "right_none" + ] + +effort: + motors: [ + "left_joint0", + "left_joint1", + "left_joint2", + "left_joint3", + "left_joint4", + "left_joint5", + "left_none", + "right_joint0", + "right_joint1", + "right_joint2", + "right_joint3", + "right_joint4", + "right_joint5", + "right_none" + ] + +action: + motors: [ + "left_joint0", + "left_joint1", + "left_joint2", + "left_joint3", + "left_joint4", + "left_joint5", + "left_none", + "right_joint0", + "right_joint1", + "right_joint2", + "right_joint3", + "right_joint4", + "right_joint5", + "right_none" + ] diff --git a/collect_data/agilex_robot.py b/collect_data/agilex_robot.py new file mode 100644 index 0000000..28a701e --- /dev/null +++ b/collect_data/agilex_robot.py @@ -0,0 +1,456 @@ +import yaml +import cv2 +import numpy as np +import collections +import dm_env +import argparse +from typing import Dict, List, Any, Optional +from collections import deque +import rospy +from cv_bridge import CvBridge +from std_msgs.msg import Header +from sensor_msgs.msg import Image, JointState +from nav_msgs.msg import Odometry +from rosrobot import Robot +import torch +import time + + +class AgilexRobot(Robot): + def get_frame(self) -> Optional[Dict[str, Any]]: + """ + 获取同步帧数据 + 返回: 包含同步数据的字典,或None如果同步失败 + """ + # 检查基本数据可用性 + # print(self.sync_img_queues.values()) + if any(len(q) == 0 for q in self.sync_img_queues.values()): + print("camera has not get image data") + return None + + if self.use_depth_image and any(len(q) == 0 for q in self.sync_depth_queues.values()): + return None + + # print(self.sync_arm_queues.values()) + # if any(len(q) == 0 for q in self.sync_arm_queues.values()): + # print("2") + # if len(self.sync_arm_queues['master_left']) == 0: + # print("can not get data from master topic") + # if len(self.sync_arm_queues['puppet_left']) == 0 or len(self.sync_arm_queues['puppet_right']) == 0: + # print("can not get data from puppet topic") + # return None + + if len(self.sync_arm_queues['puppet_left']) == 0 or len(self.sync_arm_queues['puppet_right']) == 0: + print("can not get data from puppet topic") + return None + + # 计算最小时间戳 + timestamps = [ + q[-1].header.stamp.to_sec() + for q in list(self.sync_img_queues.values()) + + list(self.sync_arm_queues.values()) + ] + + if self.use_depth_image: + timestamps.extend(q[-1].header.stamp.to_sec() for q in self.sync_depth_queues.values()) + + if self.use_robot_base and len(self.sync_base_queue) > 0: + timestamps.append(self.sync_base_queue[-1].header.stamp.to_sec()) + + min_time = min(timestamps) + + # 检查数据同步性 + for queue in list(self.sync_img_queues.values()) + list(self.sync_arm_queues.values()): + if queue[-1].header.stamp.to_sec() < min_time: + return None + + if self.use_depth_image: + for queue in self.sync_depth_queues.values(): + if queue[-1].header.stamp.to_sec() < min_time: + return None + + if self.use_robot_base and len(self.sync_base_queue) > 0: + if self.sync_base_queue[-1].header.stamp.to_sec() < min_time: + return None + + # 提取同步数据 + frame_data = { + 'images': {}, + 'arms': {}, + 'timestamp': min_time + } + + # 图像数据 + for cam_name, queue in self.sync_img_queues.items(): + while queue[0].header.stamp.to_sec() < min_time: + queue.popleft() + 
frame_data['images'][cam_name] = self.bridge.imgmsg_to_cv2(queue.popleft(), 'passthrough') + + # 深度数据 + if self.use_depth_image: + frame_data['depths'] = {} + for cam_name, queue in self.sync_depth_queues.items(): + while queue[0].header.stamp.to_sec() < min_time: + queue.popleft() + depth_img = self.bridge.imgmsg_to_cv2(queue.popleft(), 'passthrough') + # 保持原有的边界填充 + frame_data['depths'][cam_name] = cv2.copyMakeBorder( + depth_img, 40, 40, 0, 0, cv2.BORDER_CONSTANT, value=0 + ) + + # 机械臂数据 + for arm_name, queue in self.sync_arm_queues.items(): + while queue[0].header.stamp.to_sec() < min_time: + queue.popleft() + frame_data['arms'][arm_name] = queue.popleft() + + # 基座数据 + if self.use_robot_base and len(self.sync_base_queue) > 0: + while self.sync_base_queue[0].header.stamp.to_sec() < min_time: + self.sync_base_queue.popleft() + frame_data['base'] = self.sync_base_queue.popleft() + + return frame_data + + + def teleop_step(self) -> Optional[tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]]: + """ + 获取同步帧数据,输出格式与 teleop_step 一致 + 返回: (obs_dict, action_dict) 或 None 如果同步失败 + """ + # 检查基本数据可用性 + if any(len(q) == 0 for q in self.sync_img_queues.values()): + return None, None + + if self.use_depth_image and any(len(q) == 0 for q in self.sync_depth_queues.values()): + return None, None + + if any(len(q) == 0 for q in self.sync_arm_queues.values()): + return None, None + + # 计算最小时间戳 + timestamps = [ + q[-1].header.stamp.to_sec() + for q in list(self.sync_img_queues.values()) + + list(self.sync_arm_queues.values()) + ] + + if self.use_depth_image: + timestamps.extend(q[-1].header.stamp.to_sec() for q in self.sync_depth_queues.values()) + + if self.use_robot_base and len(self.sync_base_queue) > 0: + timestamps.append(self.sync_base_queue[-1].header.stamp.to_sec()) + + min_time = min(timestamps) + + # 检查数据同步性 + for queue in list(self.sync_img_queues.values()) + list(self.sync_arm_queues.values()): + if queue[-1].header.stamp.to_sec() < min_time: + return None, None + + if self.use_depth_image: + for queue in self.sync_depth_queues.values(): + if queue[-1].header.stamp.to_sec() < min_time: + return None, None + + if self.use_robot_base and len(self.sync_base_queue) > 0: + if self.sync_base_queue[-1].header.stamp.to_sec() < min_time: + return None, None + + # 初始化输出字典 + obs_dict = {} + action_dict = {} + + # 处理图像数据 + for cam_name, queue in self.sync_img_queues.items(): + while queue[0].header.stamp.to_sec() < min_time: + queue.popleft() + img = self.bridge.imgmsg_to_cv2(queue.popleft(), 'passthrough') + obs_dict[f"observation.images.{cam_name}"] = torch.from_numpy(img) + + # 处理深度数据 + if self.use_depth_image: + for cam_name, queue in self.sync_depth_queues.items(): + while queue[0].header.stamp.to_sec() < min_time: + queue.popleft() + depth_img = self.bridge.imgmsg_to_cv2(queue.popleft(), 'passthrough') + depth_img = cv2.copyMakeBorder( + depth_img, 40, 40, 0, 0, cv2.BORDER_CONSTANT, value=0 + ) + obs_dict[f"observation.images.depth_{cam_name}"] = torch.from_numpy(depth_img).unsqueeze(-1) + + # 处理机械臂观测数据 + arm_states = [] + arm_velocity = [] + arm_effort = [] + actions = [] + for arm_name, queue in self.sync_arm_queues.items(): + while queue[0].header.stamp.to_sec() < min_time: + queue.popleft() + arm_data = queue.popleft() + + # np.array(arm_data.position), + # np.array(arm_data.velocity), + # np.array(arm_data.effort) + # 如果是从臂(puppet),作为观测 + if arm_name.startswith('puppet'): + arm_states.append(np.array(arm_data.position, dtype=np.float32)) + arm_velocity.append(np.array(arm_data.velocity, 
dtype=np.float32)) + arm_effort.append(np.array(arm_data.effort, dtype=np.float32)) + + # 如果是主臂(master),作为动作 + if arm_name.startswith('master'): + # action_dict[f"action.{arm_name}"] = torch.from_numpy(np.array(arm_data.position)) + actions.append(np.array(arm_data.position, dtype=np.float32)) + + if arm_states: + obs_dict["observation.state"] = torch.tensor(np.concatenate(arm_states).reshape(-1)) # 先转Python列表 + + if arm_velocity: + obs_dict["observation.velocity"] = torch.tensor(np.concatenate(arm_velocity).reshape(-1)) + + if arm_effort: + obs_dict["observation.effort"] = torch.tensor(np.concatenate(arm_effort).reshape(-1)) + + if actions: + action_dict["action"] = torch.tensor(np.concatenate(actions).reshape(-1)) + # action_dict["action"] = np.concatenate(actions).squeeze() + + # 处理基座数据 + if self.use_robot_base and len(self.sync_base_queue) > 0: + while self.sync_base_queue[0].header.stamp.to_sec() < min_time: + self.sync_base_queue.popleft() + base_data = self.sync_base_queue.popleft() + obs_dict["observation.base_vel"] = torch.tensor([ + base_data.twist.twist.linear.x, + base_data.twist.twist.angular.z + ], dtype=torch.float32) + + # 添加时间戳 + # obs_dict["observation.timestamp"] = torch.tensor(min_time, dtype=torch.float64) + + return obs_dict, action_dict + + + def capture_observation(self): + """Capture observation data from ROS topics without batch dimension. + + Returns: + dict: Observation dictionary containing state and images. + + Raises: + RobotDeviceNotConnectedError: If robot is not connected. + """ + # Initialize observation dictionary + obs_dict = {} + + # Get synchronized frame data + frame_data = self.get_frame() + if frame_data is None: + # raise RuntimeError("Failed to capture synchronized observation data") + return None + + # Process arm state data (from puppet arms) + arm_states = [] + arm_velocity = [] + arm_effort = [] + + for arm_name, joint_state in frame_data['arms'].items(): + if arm_name.startswith('puppet'): + # Record timing for performance monitoring + before_read_t = time.perf_counter() + + # Get position data and convert to tensor + pos = torch.from_numpy(np.array(joint_state.position, dtype=np.float32)) + arm_states.append(pos) + + velocity = torch.from_numpy(np.array(joint_state.velocity, dtype=np.float32)) + arm_velocity.append(velocity) + + effort = torch.from_numpy(np.array(joint_state.effort, dtype=np.float32)) + arm_effort.append(effort) + + # Log timing information + # self.logs[f"read_arm_{arm_name}_pos_dt_s"] = time.perf_counter() - before_read_t + print(f"read_arm_{arm_name}_pos_dt_s is", time.perf_counter() - before_read_t) + + # Combine all arm states into single tensor + if arm_states: + obs_dict["observation.state"] = torch.cat(arm_states) + + if arm_velocity: + obs_dict["observation.velocity"] = torch.cat(arm_velocity) + + if arm_effort: + obs_dict["observation.effort"] = torch.cat(arm_effort) + + # Process image data + for cam_name, img in frame_data['images'].items(): + # Record timing for performance monitoring + before_camread_t = time.perf_counter() + + # Convert image to tensor + img_tensor = torch.from_numpy(img) + obs_dict[f"observation.images.{cam_name}"] = img_tensor + + # Log timing information + # self.logs[f"read_camera_{cam_name}_dt_s"] = time.perf_counter() - before_camread_t + print(f"read_camera_{cam_name}_dt_s is", time.perf_counter() - before_camread_t) + + # Process depth data if enabled + if self.use_depth_image and 'depths' in frame_data: + for cam_name, depth_img in frame_data['depths'].items(): + before_depthread_t = 
time.perf_counter() + + # Convert depth image to tensor and add channel dimension + depth_tensor = torch.from_numpy(depth_img).unsqueeze(-1) + obs_dict[f"observation.images.depth_{cam_name}"] = depth_tensor + + # self.logs[f"read_depth_{cam_name}_dt_s"] = time.perf_counter() - before_depthread_t + print(f"read_depth_{cam_name}_dt_s is", time.perf_counter() - before_depthread_t) + + # Process base velocity if enabled + if self.use_robot_base and 'base' in frame_data: + base_data = frame_data['base'] + obs_dict["observation.base_vel"] = torch.tensor([ + base_data.twist.twist.linear.x, + base_data.twist.twist.angular.z + ], dtype=torch.float32) + + return obs_dict + + def send_action(self, action: torch.Tensor) -> torch.Tensor: + """ + Send joint position commands to the puppet arms via ROS. + + Args: + action: Tensor containing concatenated goal positions for all puppet arms + Shape should match the action space defined in features["action"] + + Returns: + The actual action that was sent (may be clipped if safety checks are implemented) + """ + # if not hasattr(self, 'puppet_arm_publishers'): + # # Initialize publishers on first call + # self._init_action_publishers() + + last_velocity = [-0.010990142822265625, -0.010990142822265625, -0.03296661376953125, 0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.03296661376953125] + last_effort = [-0.021978378295898438, 0.2417583465576172, 4.320878982543945, 3.6527481079101562, -0.013187408447265625, -0.013187408447265625, 0.0, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.03296661376953125] + + # Convert tensor to numpy array if needed + if isinstance(action, torch.Tensor): + action = action.detach().cpu().numpy() + + # Split action into individual arm commands based on config + from_idx = 0 + to_idx = 0 + action_sent = [] + for arm_name, arm_config in self.arms.items(): + # 主臂topic是否存在 + if not "master" in arm_name: + continue + + # Get number of joints for this arm + num_joints = len(arm_config.get('motors', [])) + to_idx += num_joints + + # Extract this arm's portion of the action + arm_action = action[from_idx:to_idx] + arm_velocity = last_velocity[from_idx:to_idx] + arm_effort = last_effort[from_idx:to_idx] + from_idx = to_idx + + # Apply safety checks if configured + if 'max_relative_target' in self.config: + # Get current position from the queue + if len(self.sync_arm_queues[arm_name]) > 0: + current_state = self.sync_arm_queues[arm_name][-1] + current_pos = np.array(current_state.position) + + # Clip the action to stay within max relative target + max_delta = self.config['max_relative_target'] + clipped_action = np.clip(arm_action, + current_pos - max_delta, + current_pos + max_delta) + arm_action = clipped_action + + action_sent.append(arm_action) + + # Create and publish JointState message + joint_state = JointState() + joint_state.header = Header() + joint_state.header.stamp = rospy.Time.now() + joint_state.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', ''] + joint_state.position = arm_action.tolist() + joint_state.velocity = arm_velocity + joint_state.effort = arm_effort + + # Publish to the corresponding topic + self.publishers[f"arm_{arm_name}"].publish(joint_state) + + return torch.from_numpy(np.concatenate(action_sent)) if action_sent else 
torch.tensor([]) + + # def _init_action_publishers(self) -> None: + # """Initialize ROS publishers for puppet arms""" + # self.puppet_arm_publishers = {} + # # rospy.init_node("replay_node") + # for arm_name, arm_config in self.arms.items(): + # if not "puppet" in arm_name: + # # if not "master" in arm_name: + # continue + + # if 'topic_name' not in arm_config: + # rospy.logwarn(f"No puppet topic defined for arm {arm_name}") + # continue + + # self.puppet_arm_publishers[arm_name] = rospy.Publisher( + # arm_config['topic_name'], + # JointState, + # queue_size=10 + # ) + + # # Wait for publisher to connect + # rospy.sleep(0.1) + + # rospy.loginfo("Initialized puppet arm publishers") + + + + +# def get_arguments() -> argparse.Namespace: +# """获取运行时参数""" +# parser = argparse.ArgumentParser() +# parser.add_argument('--fps', type=int, help='Frame rate', default=30) +# parser.add_argument('--max_timesteps', type=int, help='Max timesteps', default=500) +# parser.add_argument('--episode_idx', type=int, help='Episode index', default=0) +# parser.add_argument('--use_depth', action='store_true', help='Use depth images') +# parser.add_argument('--use_base', action='store_true', help='Use robot base') +# return parser.parse_args() + + +# if __name__ == "__main__": +# # 示例用法 +# import json +# args = get_arguments() +# robot = AgilexRobot(config_file="/home/jgl20/LYT/work/collect_data_lerobot/1.yaml", args=args) +# print(json.dumps(robot.features, indent=4)) +# robot.warmup_record() +# count = 0 +# print_flag = True +# rate = rospy.Rate(args.fps) +# while (count < args.max_timesteps + 1) and not rospy.is_shutdown(): +# a, b = robot.teleop_step() + +# if a is None or b is None: +# if print_flag: +# print("syn fail\n") +# print_flag = False +# rate.sleep() +# continue +# else: +# print(a) + + + # timesteps, actions = robot.process() + # print(timesteps) + print() diff --git a/collect_data/collect_data_lerobot.py b/collect_data/collect_data_lerobot.py new file mode 100644 index 0000000..5a6cd22 --- /dev/null +++ b/collect_data/collect_data_lerobot.py @@ -0,0 +1,487 @@ +import logging +import time +from dataclasses import asdict +from pprint import pformat +from pprint import pprint + +# from safetensors.torch import load_file, save_file +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.policies.factory import make_policy +from lerobot.common.robot_devices.control_configs import ( + CalibrateControlConfig, + ControlPipelineConfig, + RecordControlConfig, + RemoteRobotConfig, + ReplayControlConfig, + TeleoperateControlConfig, +) +from lerobot.common.robot_devices.control_utils import ( + # init_keyboard_listener, + record_episode, + stop_recording, + is_headless +) +from lerobot.common.robot_devices.robots.utils import Robot, make_robot_from_config +from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect +from lerobot.common.utils.utils import has_method, init_logging, log_say +from lerobot.common.utils.utils import get_safe_torch_device +from contextlib import nullcontext +from copy import copy +import torch +import rospy +import cv2 +from lerobot.configs import parser +from agilex_robot import AgilexRobot + + +######################################################################################## +# Control modes +######################################################################################## + + +def predict_action(observation, policy, device, use_amp): + observation = copy(observation) + with ( + torch.inference_mode(), + 
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(), + ): + # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension + for name in observation: + if "image" in name: + observation[name] = observation[name].type(torch.float32) / 255 + observation[name] = observation[name].permute(2, 0, 1).contiguous() + observation[name] = observation[name].unsqueeze(0) + observation[name] = observation[name].to(device) + + # Compute the next action with the policy + # based on the current observation + action = policy.select_action(observation) + + # Remove batch dimension + action = action.squeeze(0) + + # Move to cpu, if not already the case + action = action.to("cpu") + + return action + +def control_loop( + robot, + control_time_s=None, + teleoperate=False, + display_cameras=False, + dataset: LeRobotDataset | None = None, + events=None, + policy = None, + fps: int | None = None, + single_task: str | None = None, +): + # TODO(rcadene): Add option to record logs + # if not robot.is_connected: + # robot.connect() + + if events is None: + events = {"exit_early": False} + + if control_time_s is None: + control_time_s = float("inf") + + if dataset is not None and single_task is None: + raise ValueError("You need to provide a task as argument in `single_task`.") + + if dataset is not None and fps is not None and dataset.fps != fps: + raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).") + + timestamp = 0 + start_episode_t = time.perf_counter() + rate = rospy.Rate(fps) + print_flag = True + while timestamp < control_time_s and not rospy.is_shutdown(): + # print(timestamp < control_time_s) + # print(rospy.is_shutdown()) + start_loop_t = time.perf_counter() + + if teleoperate: + observation, action = robot.teleop_step() + if observation is None or action is None: + if print_flag: + print("sync data fail, retrying...\n") + print_flag = False + rate.sleep() + continue + else: + # pass + observation = robot.capture_observation() + if policy is not None: + pred_action = predict_action( + observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp + ) + # Action can eventually be clipped using `max_relative_target`, + # so action actually sent is saved in the dataset. 
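+                # `pred_action` is the concatenated joint target for both master arms
+                # (left joints first, then right, following `action.motors` in agilex.yaml);
+                # send_action() splits it per arm before publishing one JointState per topic.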
+ action = robot.send_action(pred_action) + action = {"action": action} + + if dataset is not None: + frame = {**observation, **action, "task": single_task} + dataset.add_frame(frame) + + # if display_cameras and not is_headless(): + # image_keys = [key for key in observation if "image" in key] + # for key in image_keys: + # if "depth" in key: + # pass + # else: + # cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)) + + # print(1) + # cv2.waitKey(1) + + if display_cameras and not is_headless(): + image_keys = [key for key in observation if "image" in key] + + # 获取屏幕分辨率(假设屏幕分辨率为 1920x1080,可以根据实际情况调整) + screen_width = 1920 + screen_height = 1080 + + # 计算窗口的排列方式 + num_images = len(image_keys) + max_columns = int(screen_width / 640) # 假设每个窗口宽度为 640 + rows = (num_images + max_columns - 1) // max_columns # 计算需要的行数 + columns = min(num_images, max_columns) # 实际使用的列数 + + # 遍历所有图像键并显示 + for idx, key in enumerate(image_keys): + if "depth" in key: + continue # 跳过深度图像 + + # 将图像从 RGB 转换为 BGR 格式 + image = cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR) + + # 创建窗口 + cv2.imshow(key, image) + + # 计算窗口位置 + window_width = 640 + window_height = 480 + row = idx // max_columns + col = idx % max_columns + x_position = col * window_width + y_position = row * window_height + + # 移动窗口到指定位置 + cv2.moveWindow(key, x_position, y_position) + + # 等待 1 毫秒以处理事件 + cv2.waitKey(1) + + if fps is not None: + dt_s = time.perf_counter() - start_loop_t + busy_wait(1 / fps - dt_s) + + dt_s = time.perf_counter() - start_loop_t + # log_control_info(robot, dt_s, fps=fps) + + timestamp = time.perf_counter() - start_episode_t + if events["exit_early"]: + events["exit_early"] = False + break + + +def init_keyboard_listener(): + # Allow to exit early while recording an episode or resetting the environment, + # by tapping the right arrow key '->'. This might require a sudo permission + # to allow your terminal to monitor keyboard events. + events = {} + events["exit_early"] = False + events["record_start"] = False + events["rerecord_episode"] = False + events["stop_recording"] = False + + if is_headless(): + logging.warning( + "Headless environment detected. On-screen cameras display and keyboard inputs will not be available." + ) + listener = None + return listener, events + + # Only import pynput if not in a headless environment + from pynput import keyboard + + def on_press(key): + try: + if key == keyboard.Key.right: + print("Right arrow key pressed. Exiting loop...") + events["exit_early"] = True + events["record_start"] = False + elif key == keyboard.Key.left: + print("Left arrow key pressed. Exiting loop and rerecord the last episode...") + events["rerecord_episode"] = True + events["exit_early"] = True + elif key == keyboard.Key.esc: + print("Escape key pressed. Stopping data recording...") + events["stop_recording"] = True + events["exit_early"] = True + elif key == keyboard.Key.up: + print("Up arrow pressed. 
Start data recording...") + events["record_start"] = True + + + except Exception as e: + print(f"Error handling key press: {e}") + + listener = keyboard.Listener(on_press=on_press) + listener.start() + + return listener, events + + +def stop_recording(robot, listener, display_cameras): + + if not is_headless(): + if listener is not None: + listener.stop() + + if display_cameras: + cv2.destroyAllWindows() + + +def record_episode( + robot, + dataset, + events, + episode_time_s, + display_cameras, + policy, + fps, + single_task, +): + control_loop( + robot=robot, + control_time_s=episode_time_s, + display_cameras=display_cameras, + dataset=dataset, + events=events, + policy=policy, + fps=fps, + teleoperate=policy is None, + single_task=single_task, + ) + + +def record( + robot, + cfg +) -> LeRobotDataset: + # TODO(rcadene): Add option to record logs + if cfg.resume: + dataset = LeRobotDataset( + cfg.repo_id, + root=cfg.root, + ) + if len(robot.cameras) > 0: + dataset.start_image_writer( + num_processes=cfg.num_image_writer_processes, + num_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras), + ) + # sanity_check_dataset_robot_compatibility(dataset, robot, cfg.fps, cfg.video) + else: + # Create empty dataset or load existing saved episodes + # sanity_check_dataset_name(cfg.repo_id, cfg.policy) + dataset = LeRobotDataset.create( + cfg.repo_id, + cfg.fps, + root=cfg.root, + robot=None, + features=robot.features, + use_videos=cfg.video, + image_writer_processes=cfg.num_image_writer_processes, + image_writer_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras), + ) + + # Load pretrained policy + policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta) + # policy = None + + # if not robot.is_connected: + # robot.connect() + + listener, events = init_keyboard_listener() + + # Execute a few seconds without recording to: + # 1. teleoperate the robot to move it in starting position if no policy provided, + # 2. give times to the robot devices to connect and start synchronizing, + # 3. place the cameras windows on screen + enable_teleoperation = policy is None + log_say("Warmup record", cfg.play_sounds) + print() + print(f"开始记录轨迹,共需要记录{cfg.num_episodes}条\n每条轨迹的最长时间为{cfg.episode_time_s}frame\n按右方向键代表当前轨迹结束录制\n按上方面键代表当前轨迹开始录制\n按左方向键代表当前轨迹重新录制\n按ESC方向键代表退出轨迹录制\n") + # warmup_record(robot, events, enable_teleoperation, cfg.warmup_time_s, cfg.display_cameras, cfg.fps) + + # if has_method(robot, "teleop_safety_stop"): + # robot.teleop_safety_stop() + + recorded_episodes = 0 + while True: + if recorded_episodes >= cfg.num_episodes: + break + + # if events["record_start"]: + log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds) + pprint(f"Recording episode {dataset.num_episodes}, total episodes is {cfg.num_episodes}") + record_episode( + robot=robot, + dataset=dataset, + events=events, + episode_time_s=cfg.episode_time_s, + display_cameras=cfg.display_cameras, + policy=policy, + fps=cfg.fps, + single_task=cfg.single_task, + ) + + # Execute a few seconds without recording to give time to manually reset the environment + # Current code logic doesn't allow to teleoperate during this time. 
+ # TODO(rcadene): add an option to enable teleoperation during reset + # Skip reset for the last episode to be recorded + if not events["stop_recording"] and ( + (recorded_episodes < cfg.num_episodes - 1) or events["rerecord_episode"] + ): + log_say("Reset the environment", cfg.play_sounds) + pprint("Reset the environment, stop recording") + # reset_environment(robot, events, cfg.reset_time_s, cfg.fps) + + if events["rerecord_episode"]: + log_say("Re-record episode", cfg.play_sounds) + pprint("Re-record episode") + events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue + + dataset.save_episode() + recorded_episodes += 1 + + if events["stop_recording"]: + break + + log_say("Stop recording", cfg.play_sounds, blocking=True) + stop_recording(robot, listener, cfg.display_cameras) + + if cfg.push_to_hub: + dataset.push_to_hub(tags=cfg.tags, private=cfg.private) + + log_say("Exiting", cfg.play_sounds) + return dataset + + +def replay( + robot: AgilexRobot, + cfg, +): + # TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset + # TODO(rcadene): Add option to record logs + + dataset = LeRobotDataset(cfg.repo_id, root=cfg.root, episodes=[cfg.episode]) + actions = dataset.hf_dataset.select_columns("action") + + # if not robot.is_connected: + # robot.connect() + + log_say("Replaying episode", cfg.play_sounds, blocking=True) + for idx in range(dataset.num_frames): + start_episode_t = time.perf_counter() + + action = actions[idx]["action"] + robot.send_action(action) + + dt_s = time.perf_counter() - start_episode_t + busy_wait(1 / cfg.fps - dt_s) + + dt_s = time.perf_counter() - start_episode_t + # log_control_info(robot, dt_s, fps=cfg.fps) + + +import argparse +def get_arguments(): + parser = argparse.ArgumentParser() + args = parser.parse_args() + args.fps = 30 + args.resume = False + args.repo_id = "move_the_bottle_from_the_right_to_the_scale_right" + args.root = "/home/ubuntu/LYT/aloha_lerobot/data4" + args.episode = 0 # replay episode + args.num_image_writer_processes = 0 + args.num_image_writer_threads_per_camera = 4 + args.video = True + args.num_episodes = 100 + args.episode_time_s = 30000 + args.play_sounds = False + args.display_cameras = True + args.single_task = "move the bottle from the right to the scale right" + args.use_depth_image = False + args.use_base = False + args.push_to_hub = False + args.policy = None + # args.teleoprate = True + args.control_type = "record" + # args.control_type = "replay" + return args + + + +# @parser.wrap() +# def control_robot(cfg: ControlPipelineConfig): +# init_logging() +# logging.info(pformat(asdict(cfg))) + +# # robot = make_robot_from_config(cfg.robot) +# from agilex_robot import AgilexRobot +# robot = AgilexRobot(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg) + +# if isinstance(cfg.control, RecordControlConfig): +# print(cfg.control) +# record(robot, cfg.control) +# elif isinstance(cfg.control, ReplayControlConfig): +# replay(robot, cfg.control) + +# # if robot.is_connected: +# # # Disconnect manually to avoid a "Core dump" during process +# # # termination due to camera threads not properly exiting. 
+# # robot.disconnect() + + +# @parser.wrap() +def control_robot(cfg): + + # robot = make_robot_from_config(cfg.robot) + from agilex_robot import AgilexRobot + robot = AgilexRobot(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg) + + if cfg.control_type == "record": + record(robot, cfg) + elif cfg.control_type == "replay": + replay(robot, cfg) + + # if robot.is_connected: + # # Disconnect manually to avoid a "Core dump" during process + # # termination due to camera threads not properly exiting. + # robot.disconnect() + +if __name__ == "__main__": + cfg = get_arguments() + control_robot(cfg) + # control_robot() + # cfg = get_arguments() + # from agilex_robot import AgilexRobot + # robot = AgilexRobot(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg) + # print(robot.features.items()) + # print([key for key, ft in robot.features.items() if ft["dtype"] == "video"]) + # record(robot, cfg) + # capture = robot.capture_observation() + # import torch + # torch.save(capture, "test.pt") + # action = torch.tensor([[ 0.0277, 0.0167, 0.0142, -0.1628, 0.1473, -0.0296, 0.0238, -0.1094, + # 0.0109, 0.0139, -0.1591, -0.1490, -0.1650, -0.0980]], + # device='cpu') + # robot.send_action(action.squeeze(0)) + # print() \ No newline at end of file diff --git a/collect_data/export_env.bash b/collect_data/export_env.bash new file mode 100644 index 0000000..58be2a7 --- /dev/null +++ b/collect_data/export_env.bash @@ -0,0 +1,3 @@ +export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtiff.so.5 +# export LD_LIBRARY_PATH=/home/ubuntu/miniconda3/envs/lerobot/lib:$LD_LIBRARY_PATH +export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH diff --git a/collect_data/inference.py b/collect_data/inference.py new file mode 100644 index 0000000..34f7f52 --- /dev/null +++ b/collect_data/inference.py @@ -0,0 +1,769 @@ +#!/home/lin/software/miniconda3/envs/aloha/bin/python +# -- coding: UTF-8 +""" +#!/usr/bin/python3 +""" + +import torch +import numpy as np +import os +import pickle +import argparse +from einops import rearrange +import collections +from collections import deque + +import rospy +from std_msgs.msg import Header +from geometry_msgs.msg import Twist +from sensor_msgs.msg import JointState, Image +from nav_msgs.msg import Odometry +from cv_bridge import CvBridge +import time +import threading +import math +import threading + + + + +import sys +sys.path.append("./") + +SEED = 42 +torch.manual_seed(SEED) +np.random.seed(SEED) + +task_config = {'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']} + +inference_thread = None +inference_lock = threading.Lock() +inference_actions = None +inference_timestep = None + + +def actions_interpolation(args, pre_action, actions, stats): + steps = np.concatenate((np.array(args.arm_steps_length), np.array(args.arm_steps_length)), axis=0) + pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std'] + post_process = lambda a: a * stats['action_std'] + stats['action_mean'] + result = [pre_action] + post_action = post_process(actions[0]) + # print("pre_action:", pre_action[7:]) + # print("actions_interpolation1:", post_action[:, 7:]) + max_diff_index = 0 + max_diff = -1 + for i in range(post_action.shape[0]): + diff = 0 + for j in range(pre_action.shape[0]): + if j == 6 or j == 13: + continue + diff += math.fabs(pre_action[j] - post_action[i][j]) + if diff > max_diff: + max_diff = diff + max_diff_index = i + + for i in range(max_diff_index, post_action.shape[0]): + step = 
max([math.floor(math.fabs(result[-1][j] - post_action[i][j])/steps[j]) for j in range(pre_action.shape[0])]) + inter = np.linspace(result[-1], post_action[i], step+2) + result.extend(inter[1:]) + while len(result) < args.chunk_size+1: + result.append(result[-1]) + result = np.array(result)[1:args.chunk_size+1] + # print("actions_interpolation2:", result.shape, result[:, 7:]) + result = pre_process(result) + result = result[np.newaxis, :] + return result + + +def get_model_config(args): + # 设置随机种子,你可以确保在相同的初始条件下,每次运行代码时生成的随机数序列是相同的。 + set_seed(1) + + # 如果是ACT策略 + # fixed parameters + if args.policy_class == 'ACT': + policy_config = {'lr': args.lr, + 'lr_backbone': args.lr_backbone, + 'backbone': args.backbone, + 'masks': args.masks, + 'weight_decay': args.weight_decay, + 'dilation': args.dilation, + 'position_embedding': args.position_embedding, + 'loss_function': args.loss_function, + 'chunk_size': args.chunk_size, # 查询 + 'camera_names': task_config['camera_names'], + 'use_depth_image': args.use_depth_image, + 'use_robot_base': args.use_robot_base, + 'kl_weight': args.kl_weight, # kl散度权重 + 'hidden_dim': args.hidden_dim, # 隐藏层维度 + 'dim_feedforward': args.dim_feedforward, + 'enc_layers': args.enc_layers, + 'dec_layers': args.dec_layers, + 'nheads': args.nheads, + 'dropout': args.dropout, + 'pre_norm': args.pre_norm + } + elif args.policy_class == 'CNNMLP': + policy_config = {'lr': args.lr, + 'lr_backbone': args.lr_backbone, + 'backbone': args.backbone, + 'masks': args.masks, + 'weight_decay': args.weight_decay, + 'dilation': args.dilation, + 'position_embedding': args.position_embedding, + 'loss_function': args.loss_function, + 'chunk_size': 1, # 查询 + 'camera_names': task_config['camera_names'], + 'use_depth_image': args.use_depth_image, + 'use_robot_base': args.use_robot_base + } + + elif args.policy_class == 'Diffusion': + policy_config = {'lr': args.lr, + 'lr_backbone': args.lr_backbone, + 'backbone': args.backbone, + 'masks': args.masks, + 'weight_decay': args.weight_decay, + 'dilation': args.dilation, + 'position_embedding': args.position_embedding, + 'loss_function': args.loss_function, + 'chunk_size': args.chunk_size, # 查询 + 'camera_names': task_config['camera_names'], + 'use_depth_image': args.use_depth_image, + 'use_robot_base': args.use_robot_base, + 'observation_horizon': args.observation_horizon, + 'action_horizon': args.action_horizon, + 'num_inference_timesteps': args.num_inference_timesteps, + 'ema_power': args.ema_power + } + else: + raise NotImplementedError + + config = { + 'ckpt_dir': args.ckpt_dir, + 'ckpt_name': args.ckpt_name, + 'ckpt_stats_name': args.ckpt_stats_name, + 'episode_len': args.max_publish_step, + 'state_dim': args.state_dim, + 'policy_class': args.policy_class, + 'policy_config': policy_config, + 'temporal_agg': args.temporal_agg, + 'camera_names': task_config['camera_names'], + } + return config + + +def make_policy(policy_class, policy_config): + if policy_class == 'ACT': + policy = ACTPolicy(policy_config) + elif policy_class == 'CNNMLP': + policy = CNNMLPPolicy(policy_config) + elif policy_class == 'Diffusion': + policy = DiffusionPolicy(policy_config) + else: + raise NotImplementedError + return policy + + +def get_image(observation, camera_names): + curr_images = [] + for cam_name in camera_names: + curr_image = rearrange(observation['images'][cam_name], 'h w c -> c h w') + + curr_images.append(curr_image) + curr_image = np.stack(curr_images, axis=0) + curr_image = torch.from_numpy(curr_image / 255.0).float().cuda().unsqueeze(0) + return curr_image + 
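+# get_image() rearranges each camera frame from HWC to CHW, stacks the cameras along a new
+# leading axis, scales to [0, 1] and adds a batch dimension, so the policy receives a
+# float32 CUDA tensor of shape (1, num_cameras, C, H, W).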
+ +def get_depth_image(observation, camera_names): + curr_images = [] + for cam_name in camera_names: + curr_images.append(observation['images_depth'][cam_name]) + curr_image = np.stack(curr_images, axis=0) + curr_image = torch.from_numpy(curr_image / 255.0).float().cuda().unsqueeze(0) + return curr_image + + +def inference_process(args, config, ros_operator, policy, stats, t, pre_action): + global inference_lock + global inference_actions + global inference_timestep + print_flag = True + pre_pos_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std'] + pre_action_process = lambda next_action: (next_action - stats["action_mean"]) / stats["action_std"] + rate = rospy.Rate(args.publish_rate) + while True and not rospy.is_shutdown(): + result = ros_operator.get_frame() + if not result: + if print_flag: + print("syn fail") + print_flag = False + rate.sleep() + continue + print_flag = True + (img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth, + puppet_arm_left, puppet_arm_right, robot_base) = result + obs = collections.OrderedDict() + image_dict = dict() + + image_dict[config['camera_names'][0]] = img_front + image_dict[config['camera_names'][1]] = img_left + image_dict[config['camera_names'][2]] = img_right + + + obs['images'] = image_dict + + if args.use_depth_image: + image_depth_dict = dict() + image_depth_dict[config['camera_names'][0]] = img_front_depth + image_depth_dict[config['camera_names'][1]] = img_left_depth + image_depth_dict[config['camera_names'][2]] = img_right_depth + obs['images_depth'] = image_depth_dict + + obs['qpos'] = np.concatenate( + (np.array(puppet_arm_left.position), np.array(puppet_arm_right.position)), axis=0) + obs['qvel'] = np.concatenate( + (np.array(puppet_arm_left.velocity), np.array(puppet_arm_right.velocity)), axis=0) + obs['effort'] = np.concatenate( + (np.array(puppet_arm_left.effort), np.array(puppet_arm_right.effort)), axis=0) + if args.use_robot_base: + obs['base_vel'] = [robot_base.twist.twist.linear.x, robot_base.twist.twist.angular.z] + obs['qpos'] = np.concatenate((obs['qpos'], obs['base_vel']), axis=0) + else: + obs['base_vel'] = [0.0, 0.0] + # qpos_numpy = np.array(obs['qpos']) + + # 归一化处理qpos 并转到cuda + qpos = pre_pos_process(obs['qpos']) + qpos = torch.from_numpy(qpos).float().cuda().unsqueeze(0) + # 当前图像curr_image获取图像 + curr_image = get_image(obs, config['camera_names']) + curr_depth_image = None + if args.use_depth_image: + curr_depth_image = get_depth_image(obs, config['camera_names']) + start_time = time.time() + all_actions = policy(curr_image, curr_depth_image, qpos) + end_time = time.time() + print("model cost time: ", end_time -start_time) + inference_lock.acquire() + inference_actions = all_actions.cpu().detach().numpy() + if pre_action is None: + pre_action = obs['qpos'] + # print("obs['qpos']:", obs['qpos'][7:]) + if args.use_actions_interpolation: + inference_actions = actions_interpolation(args, pre_action, inference_actions, stats) + inference_timestep = t + inference_lock.release() + break + + +def model_inference(args, config, ros_operator, save_episode=True): + global inference_lock + global inference_actions + global inference_timestep + global inference_thread + set_seed(1000) + + # 1 创建模型数据 继承nn.Module + policy = make_policy(config['policy_class'], config['policy_config']) + # print("model structure\n", policy.model) + + # 2 加载模型权重 + ckpt_path = os.path.join(config['ckpt_dir'], config['ckpt_name']) + state_dict = torch.load(ckpt_path) + new_state_dict = {} + for key, value in 
state_dict.items(): + if key in ["model.is_pad_head.weight", "model.is_pad_head.bias"]: + continue + if key in ["model.input_proj_next_action.weight", "model.input_proj_next_action.bias"]: + continue + new_state_dict[key] = value + loading_status = policy.deserialize(new_state_dict) + if not loading_status: + print("ckpt path not exist") + return False + + # 3 模型设置为cuda模式和验证模式 + policy.cuda() + policy.eval() + + # 4 加载统计值 + stats_path = os.path.join(config['ckpt_dir'], config['ckpt_stats_name']) + # 统计的数据 # 加载action_mean, action_std, qpos_mean, qpos_std 14维 + with open(stats_path, 'rb') as f: + stats = pickle.load(f) + + # 数据预处理和后处理函数定义 + pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std'] + post_process = lambda a: a * stats['action_std'] + stats['action_mean'] + + max_publish_step = config['episode_len'] + chunk_size = config['policy_config']['chunk_size'] + + # 发布基础的姿态 + left0 = [-0.00133514404296875, 0.00209808349609375, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, 3.557830810546875] + right0 = [-0.00133514404296875, 0.00438690185546875, 0.034523963928222656, -0.053597450256347656, -0.00476837158203125, -0.00209808349609375, 3.557830810546875] + left1 = [-0.00133514404296875, 0.00209808349609375, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, -0.3393220901489258] + right1 = [-0.00133514404296875, 0.00247955322265625, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, -0.3397035598754883] + + ros_operator.puppet_arm_publish_continuous(left0, right0) + input("Enter any key to continue :") + ros_operator.puppet_arm_publish_continuous(left1, right1) + action = None + # 推理 + with torch.inference_mode(): + while True and not rospy.is_shutdown(): + # 每个回合的步数 + t = 0 + max_t = 0 + rate = rospy.Rate(args.publish_rate) + if config['temporal_agg']: + all_time_actions = np.zeros([max_publish_step, max_publish_step + chunk_size, config['state_dim']]) + while t < max_publish_step and not rospy.is_shutdown(): + # start_time = time.time() + # query policy + if config['policy_class'] == "ACT": + if t >= max_t: + pre_action = action + inference_thread = threading.Thread(target=inference_process, + args=(args, config, ros_operator, + policy, stats, t, pre_action)) + inference_thread.start() + inference_thread.join() + inference_lock.acquire() + if inference_actions is not None: + inference_thread = None + all_actions = inference_actions + inference_actions = None + max_t = t + args.pos_lookahead_step + if config['temporal_agg']: + all_time_actions[[t], t:t + chunk_size] = all_actions + inference_lock.release() + if config['temporal_agg']: + actions_for_curr_step = all_time_actions[:, t] + actions_populated = np.all(actions_for_curr_step != 0, axis=1) + actions_for_curr_step = actions_for_curr_step[actions_populated] + k = 0.01 + exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step))) + exp_weights = exp_weights / exp_weights.sum() + exp_weights = exp_weights[:, np.newaxis] + raw_action = (actions_for_curr_step * exp_weights).sum(axis=0, keepdims=True) + else: + if args.pos_lookahead_step != 0: + raw_action = all_actions[:, t % args.pos_lookahead_step] + else: + raw_action = all_actions[:, t % chunk_size] + else: + raise NotImplementedError + action = post_process(raw_action[0]) + left_action = action[:7] # 取7维度 + right_action = action[7:14] + ros_operator.puppet_arm_publish(left_action, right_action) # puppet_arm_publish_continuous_thread + if 
args.use_robot_base: + vel_action = action[14:16] + ros_operator.robot_base_publish(vel_action) + t += 1 + # end_time = time.time() + # print("publish: ", t) + # print("time:", end_time - start_time) + # print("left_action:", left_action) + # print("right_action:", right_action) + rate.sleep() + + +class RosOperator: + def __init__(self, args): + self.robot_base_deque = None + self.puppet_arm_right_deque = None + self.puppet_arm_left_deque = None + self.img_front_deque = None + self.img_right_deque = None + self.img_left_deque = None + self.img_front_depth_deque = None + self.img_right_depth_deque = None + self.img_left_depth_deque = None + self.bridge = None + self.puppet_arm_left_publisher = None + self.puppet_arm_right_publisher = None + self.robot_base_publisher = None + self.puppet_arm_publish_thread = None + self.puppet_arm_publish_lock = None + self.args = args + self.ctrl_state = False + self.ctrl_state_lock = threading.Lock() + self.init() + self.init_ros() + + def init(self): + self.bridge = CvBridge() + self.img_left_deque = deque() + self.img_right_deque = deque() + self.img_front_deque = deque() + self.img_left_depth_deque = deque() + self.img_right_depth_deque = deque() + self.img_front_depth_deque = deque() + self.puppet_arm_left_deque = deque() + self.puppet_arm_right_deque = deque() + self.robot_base_deque = deque() + self.puppet_arm_publish_lock = threading.Lock() + self.puppet_arm_publish_lock.acquire() + + def puppet_arm_publish(self, left, right): + joint_state_msg = JointState() + joint_state_msg.header = Header() + joint_state_msg.header.stamp = rospy.Time.now() # 设置时间戳 + joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6'] # 设置关节名称 + joint_state_msg.position = left + self.puppet_arm_left_publisher.publish(joint_state_msg) + joint_state_msg.position = right + self.puppet_arm_right_publisher.publish(joint_state_msg) + + def robot_base_publish(self, vel): + vel_msg = Twist() + vel_msg.linear.x = vel[0] + vel_msg.linear.y = 0 + vel_msg.linear.z = 0 + vel_msg.angular.x = 0 + vel_msg.angular.y = 0 + vel_msg.angular.z = vel[1] + self.robot_base_publisher.publish(vel_msg) + + def puppet_arm_publish_continuous(self, left, right): + rate = rospy.Rate(self.args.publish_rate) + left_arm = None + right_arm = None + while True and not rospy.is_shutdown(): + if len(self.puppet_arm_left_deque) != 0: + left_arm = list(self.puppet_arm_left_deque[-1].position) + if len(self.puppet_arm_right_deque) != 0: + right_arm = list(self.puppet_arm_right_deque[-1].position) + if left_arm is None or right_arm is None: + rate.sleep() + continue + else: + break + left_symbol = [1 if left[i] - left_arm[i] > 0 else -1 for i in range(len(left))] + right_symbol = [1 if right[i] - right_arm[i] > 0 else -1 for i in range(len(right))] + flag = True + step = 0 + while flag and not rospy.is_shutdown(): + if self.puppet_arm_publish_lock.acquire(False): + return + left_diff = [abs(left[i] - left_arm[i]) for i in range(len(left))] + right_diff = [abs(right[i] - right_arm[i]) for i in range(len(right))] + flag = False + for i in range(len(left)): + if left_diff[i] < self.args.arm_steps_length[i]: + left_arm[i] = left[i] + else: + left_arm[i] += left_symbol[i] * self.args.arm_steps_length[i] + flag = True + for i in range(len(right)): + if right_diff[i] < self.args.arm_steps_length[i]: + right_arm[i] = right[i] + else: + right_arm[i] += right_symbol[i] * self.args.arm_steps_length[i] + flag = True + joint_state_msg = JointState() + joint_state_msg.header = Header() + 
joint_state_msg.header.stamp = rospy.Time.now() # 设置时间戳 + joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6'] # 设置关节名称 + joint_state_msg.position = left_arm + self.puppet_arm_left_publisher.publish(joint_state_msg) + joint_state_msg.position = right_arm + self.puppet_arm_right_publisher.publish(joint_state_msg) + step += 1 + print("puppet_arm_publish_continuous:", step) + rate.sleep() + + def puppet_arm_publish_linear(self, left, right): + num_step = 100 + rate = rospy.Rate(200) + + left_arm = None + right_arm = None + + while True and not rospy.is_shutdown(): + if len(self.puppet_arm_left_deque) != 0: + left_arm = list(self.puppet_arm_left_deque[-1].position) + if len(self.puppet_arm_right_deque) != 0: + right_arm = list(self.puppet_arm_right_deque[-1].position) + if left_arm is None or right_arm is None: + rate.sleep() + continue + else: + break + + traj_left_list = np.linspace(left_arm, left, num_step) + traj_right_list = np.linspace(right_arm, right, num_step) + + for i in range(len(traj_left_list)): + traj_left = traj_left_list[i] + traj_right = traj_right_list[i] + traj_left[-1] = left[-1] + traj_right[-1] = right[-1] + joint_state_msg = JointState() + joint_state_msg.header = Header() + joint_state_msg.header.stamp = rospy.Time.now() # 设置时间戳 + joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6'] # 设置关节名称 + joint_state_msg.position = traj_left + self.puppet_arm_left_publisher.publish(joint_state_msg) + joint_state_msg.position = traj_right + self.puppet_arm_right_publisher.publish(joint_state_msg) + rate.sleep() + + def puppet_arm_publish_continuous_thread(self, left, right): + if self.puppet_arm_publish_thread is not None: + self.puppet_arm_publish_lock.release() + self.puppet_arm_publish_thread.join() + self.puppet_arm_publish_lock.acquire(False) + self.puppet_arm_publish_thread = None + self.puppet_arm_publish_thread = threading.Thread(target=self.puppet_arm_publish_continuous, args=(left, right)) + self.puppet_arm_publish_thread.start() + + def get_frame(self): + if len(self.img_left_deque) == 0 or len(self.img_right_deque) == 0 or len(self.img_front_deque) == 0 or \ + (self.args.use_depth_image and (len(self.img_left_depth_deque) == 0 or len(self.img_right_depth_deque) == 0 or len(self.img_front_depth_deque) == 0)): + return False + if self.args.use_depth_image: + frame_time = min([self.img_left_deque[-1].header.stamp.to_sec(), self.img_right_deque[-1].header.stamp.to_sec(), self.img_front_deque[-1].header.stamp.to_sec(), + self.img_left_depth_deque[-1].header.stamp.to_sec(), self.img_right_depth_deque[-1].header.stamp.to_sec(), self.img_front_depth_deque[-1].header.stamp.to_sec()]) + else: + frame_time = min([self.img_left_deque[-1].header.stamp.to_sec(), self.img_right_deque[-1].header.stamp.to_sec(), self.img_front_deque[-1].header.stamp.to_sec()]) + + if len(self.img_left_deque) == 0 or self.img_left_deque[-1].header.stamp.to_sec() < frame_time: + return False + if len(self.img_right_deque) == 0 or self.img_right_deque[-1].header.stamp.to_sec() < frame_time: + return False + if len(self.img_front_deque) == 0 or self.img_front_deque[-1].header.stamp.to_sec() < frame_time: + return False + if len(self.puppet_arm_left_deque) == 0 or self.puppet_arm_left_deque[-1].header.stamp.to_sec() < frame_time: + return False + if len(self.puppet_arm_right_deque) == 0 or self.puppet_arm_right_deque[-1].header.stamp.to_sec() < frame_time: + return False + if self.args.use_depth_image and 
(len(self.img_left_depth_deque) == 0 or self.img_left_depth_deque[-1].header.stamp.to_sec() < frame_time): + return False + if self.args.use_depth_image and (len(self.img_right_depth_deque) == 0 or self.img_right_depth_deque[-1].header.stamp.to_sec() < frame_time): + return False + if self.args.use_depth_image and (len(self.img_front_depth_deque) == 0 or self.img_front_depth_deque[-1].header.stamp.to_sec() < frame_time): + return False + if self.args.use_robot_base and (len(self.robot_base_deque) == 0 or self.robot_base_deque[-1].header.stamp.to_sec() < frame_time): + return False + + while self.img_left_deque[0].header.stamp.to_sec() < frame_time: + self.img_left_deque.popleft() + img_left = self.bridge.imgmsg_to_cv2(self.img_left_deque.popleft(), 'passthrough') + + while self.img_right_deque[0].header.stamp.to_sec() < frame_time: + self.img_right_deque.popleft() + img_right = self.bridge.imgmsg_to_cv2(self.img_right_deque.popleft(), 'passthrough') + + while self.img_front_deque[0].header.stamp.to_sec() < frame_time: + self.img_front_deque.popleft() + img_front = self.bridge.imgmsg_to_cv2(self.img_front_deque.popleft(), 'passthrough') + + while self.puppet_arm_left_deque[0].header.stamp.to_sec() < frame_time: + self.puppet_arm_left_deque.popleft() + puppet_arm_left = self.puppet_arm_left_deque.popleft() + + while self.puppet_arm_right_deque[0].header.stamp.to_sec() < frame_time: + self.puppet_arm_right_deque.popleft() + puppet_arm_right = self.puppet_arm_right_deque.popleft() + + img_left_depth = None + if self.args.use_depth_image: + while self.img_left_depth_deque[0].header.stamp.to_sec() < frame_time: + self.img_left_depth_deque.popleft() + img_left_depth = self.bridge.imgmsg_to_cv2(self.img_left_depth_deque.popleft(), 'passthrough') + + img_right_depth = None + if self.args.use_depth_image: + while self.img_right_depth_deque[0].header.stamp.to_sec() < frame_time: + self.img_right_depth_deque.popleft() + img_right_depth = self.bridge.imgmsg_to_cv2(self.img_right_depth_deque.popleft(), 'passthrough') + + img_front_depth = None + if self.args.use_depth_image: + while self.img_front_depth_deque[0].header.stamp.to_sec() < frame_time: + self.img_front_depth_deque.popleft() + img_front_depth = self.bridge.imgmsg_to_cv2(self.img_front_depth_deque.popleft(), 'passthrough') + + robot_base = None + if self.args.use_robot_base: + while self.robot_base_deque[0].header.stamp.to_sec() < frame_time: + self.robot_base_deque.popleft() + robot_base = self.robot_base_deque.popleft() + + return (img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth, + puppet_arm_left, puppet_arm_right, robot_base) + + def img_left_callback(self, msg): + if len(self.img_left_deque) >= 2000: + self.img_left_deque.popleft() + self.img_left_deque.append(msg) + + def img_right_callback(self, msg): + if len(self.img_right_deque) >= 2000: + self.img_right_deque.popleft() + self.img_right_deque.append(msg) + + def img_front_callback(self, msg): + if len(self.img_front_deque) >= 2000: + self.img_front_deque.popleft() + self.img_front_deque.append(msg) + + def img_left_depth_callback(self, msg): + if len(self.img_left_depth_deque) >= 2000: + self.img_left_depth_deque.popleft() + self.img_left_depth_deque.append(msg) + + def img_right_depth_callback(self, msg): + if len(self.img_right_depth_deque) >= 2000: + self.img_right_depth_deque.popleft() + self.img_right_depth_deque.append(msg) + + def img_front_depth_callback(self, msg): + if len(self.img_front_depth_deque) >= 2000: + 
self.img_front_depth_deque.popleft() + self.img_front_depth_deque.append(msg) + + def puppet_arm_left_callback(self, msg): + if len(self.puppet_arm_left_deque) >= 2000: + self.puppet_arm_left_deque.popleft() + self.puppet_arm_left_deque.append(msg) + + def puppet_arm_right_callback(self, msg): + if len(self.puppet_arm_right_deque) >= 2000: + self.puppet_arm_right_deque.popleft() + self.puppet_arm_right_deque.append(msg) + + def robot_base_callback(self, msg): + if len(self.robot_base_deque) >= 2000: + self.robot_base_deque.popleft() + self.robot_base_deque.append(msg) + + def ctrl_callback(self, msg): + self.ctrl_state_lock.acquire() + self.ctrl_state = msg.data + self.ctrl_state_lock.release() + + def get_ctrl_state(self): + self.ctrl_state_lock.acquire() + state = self.ctrl_state + self.ctrl_state_lock.release() + return state + + def init_ros(self): + rospy.init_node('joint_state_publisher', anonymous=True) + rospy.Subscriber(self.args.img_left_topic, Image, self.img_left_callback, queue_size=1000, tcp_nodelay=True) + rospy.Subscriber(self.args.img_right_topic, Image, self.img_right_callback, queue_size=1000, tcp_nodelay=True) + rospy.Subscriber(self.args.img_front_topic, Image, self.img_front_callback, queue_size=1000, tcp_nodelay=True) + if self.args.use_depth_image: + rospy.Subscriber(self.args.img_left_depth_topic, Image, self.img_left_depth_callback, queue_size=1000, tcp_nodelay=True) + rospy.Subscriber(self.args.img_right_depth_topic, Image, self.img_right_depth_callback, queue_size=1000, tcp_nodelay=True) + rospy.Subscriber(self.args.img_front_depth_topic, Image, self.img_front_depth_callback, queue_size=1000, tcp_nodelay=True) + rospy.Subscriber(self.args.puppet_arm_left_topic, JointState, self.puppet_arm_left_callback, queue_size=1000, tcp_nodelay=True) + rospy.Subscriber(self.args.puppet_arm_right_topic, JointState, self.puppet_arm_right_callback, queue_size=1000, tcp_nodelay=True) + rospy.Subscriber(self.args.robot_base_topic, Odometry, self.robot_base_callback, queue_size=1000, tcp_nodelay=True) + self.puppet_arm_left_publisher = rospy.Publisher(self.args.puppet_arm_left_cmd_topic, JointState, queue_size=10) + self.puppet_arm_right_publisher = rospy.Publisher(self.args.puppet_arm_right_cmd_topic, JointState, queue_size=10) + self.robot_base_publisher = rospy.Publisher(self.args.robot_base_cmd_topic, Twist, queue_size=10) + + +def get_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument('--ckpt_dir', action='store', type=str, help='ckpt_dir', required=True) + parser.add_argument('--task_name', action='store', type=str, help='task_name', default='aloha_mobile_dummy', required=False) + parser.add_argument('--max_publish_step', action='store', type=int, help='max_publish_step', default=10000, required=False) + parser.add_argument('--ckpt_name', action='store', type=str, help='ckpt_name', default='policy_best.ckpt', required=False) + parser.add_argument('--ckpt_stats_name', action='store', type=str, help='ckpt_stats_name', default='dataset_stats.pkl', required=False) + parser.add_argument('--policy_class', action='store', type=str, help='policy_class, capitalize', default='ACT', required=False) + parser.add_argument('--batch_size', action='store', type=int, help='batch_size', default=8, required=False) + parser.add_argument('--seed', action='store', type=int, help='seed', default=0, required=False) + parser.add_argument('--num_epochs', action='store', type=int, help='num_epochs', default=2000, required=False) + parser.add_argument('--lr', action='store', 
type=float, help='lr', default=1e-5, required=False) + parser.add_argument('--weight_decay', type=float, help='weight_decay', default=1e-4, required=False) + parser.add_argument('--dilation', action='store_true', + help="If true, we replace stride with dilation in the last convolutional block (DC5)", required=False) + parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'), + help="Type of positional embedding to use on top of the image features", required=False) + parser.add_argument('--masks', action='store_true', + help="Train segmentation head if the flag is provided") + parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', default=10, required=False) + parser.add_argument('--hidden_dim', action='store', type=int, help='hidden_dim', default=512, required=False) + parser.add_argument('--dim_feedforward', action='store', type=int, help='dim_feedforward', default=3200, required=False) + parser.add_argument('--temporal_agg', action='store', type=bool, help='temporal_agg', default=True, required=False) + + parser.add_argument('--state_dim', action='store', type=int, help='state_dim', default=14, required=False) + parser.add_argument('--lr_backbone', action='store', type=float, help='lr_backbone', default=1e-5, required=False) + parser.add_argument('--backbone', action='store', type=str, help='backbone', default='resnet18', required=False) + parser.add_argument('--loss_function', action='store', type=str, help='loss_function l1 l2 l1+l2', default='l1', required=False) + parser.add_argument('--enc_layers', action='store', type=int, help='enc_layers', default=4, required=False) + parser.add_argument('--dec_layers', action='store', type=int, help='dec_layers', default=7, required=False) + parser.add_argument('--nheads', action='store', type=int, help='nheads', default=8, required=False) + parser.add_argument('--dropout', default=0.1, type=float, help="Dropout applied in the transformer", required=False) + parser.add_argument('--pre_norm', action='store_true', required=False) + + parser.add_argument('--img_front_topic', action='store', type=str, help='img_front_topic', + default='/camera_f/color/image_raw', required=False) + parser.add_argument('--img_left_topic', action='store', type=str, help='img_left_topic', + default='/camera_l/color/image_raw', required=False) + parser.add_argument('--img_right_topic', action='store', type=str, help='img_right_topic', + default='/camera_r/color/image_raw', required=False) + + parser.add_argument('--img_front_depth_topic', action='store', type=str, help='img_front_depth_topic', + default='/camera_f/depth/image_raw', required=False) + parser.add_argument('--img_left_depth_topic', action='store', type=str, help='img_left_depth_topic', + default='/camera_l/depth/image_raw', required=False) + parser.add_argument('--img_right_depth_topic', action='store', type=str, help='img_right_depth_topic', + default='/camera_r/depth/image_raw', required=False) + + parser.add_argument('--puppet_arm_left_cmd_topic', action='store', type=str, help='puppet_arm_left_cmd_topic', + default='/master/joint_left', required=False) + parser.add_argument('--puppet_arm_right_cmd_topic', action='store', type=str, help='puppet_arm_right_cmd_topic', + default='/master/joint_right', required=False) + parser.add_argument('--puppet_arm_left_topic', action='store', type=str, help='puppet_arm_left_topic', + default='/puppet/joint_left', required=False) + parser.add_argument('--puppet_arm_right_topic', action='store', type=str, 
help='puppet_arm_right_topic', + default='/puppet/joint_right', required=False) + + parser.add_argument('--robot_base_topic', action='store', type=str, help='robot_base_topic', + default='/odom_raw', required=False) + parser.add_argument('--robot_base_cmd_topic', action='store', type=str, help='robot_base_topic', + default='/cmd_vel', required=False) + parser.add_argument('--use_robot_base', action='store', type=bool, help='use_robot_base', + default=False, required=False) + parser.add_argument('--publish_rate', action='store', type=int, help='publish_rate', + default=40, required=False) + parser.add_argument('--pos_lookahead_step', action='store', type=int, help='pos_lookahead_step', + default=0, required=False) + parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size', + default=32, required=False) + parser.add_argument('--arm_steps_length', action='store', type=float, help='arm_steps_length', + default=[0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.2], required=False) + + parser.add_argument('--use_actions_interpolation', action='store', type=bool, help='use_actions_interpolation', + default=False, required=False) + parser.add_argument('--use_depth_image', action='store', type=bool, help='use_depth_image', + default=False, required=False) + + # for Diffusion + parser.add_argument('--observation_horizon', action='store', type=int, help='observation_horizon', default=1, required=False) + parser.add_argument('--action_horizon', action='store', type=int, help='action_horizon', default=8, required=False) + parser.add_argument('--num_inference_timesteps', action='store', type=int, help='num_inference_timesteps', default=10, required=False) + parser.add_argument('--ema_power', action='store', type=int, help='ema_power', default=0.75, required=False) + args = parser.parse_args() + return args + + +def main(): + args = get_arguments() + ros_operator = RosOperator(args) + config = get_model_config(args) + model_inference(args, config, ros_operator, save_episode=True) + + +if __name__ == '__main__': + main() +# python act/inference.py --ckpt_dir ~/train0314/ \ No newline at end of file diff --git a/collect_data/read_parquet.py b/collect_data/read_parquet.py new file mode 100644 index 0000000..577a1e3 --- /dev/null +++ b/collect_data/read_parquet.py @@ -0,0 +1,33 @@ +import pandas as pd + +def read_and_print_parquet_row(file_path, row_index=0): + """ + 读取Parquet文件并打印指定行的数据 + + 参数: + file_path (str): Parquet文件路径 + row_index (int): 要打印的行索引(默认为第0行) + """ + try: + # 读取Parquet文件 + df = pd.read_parquet(file_path) + + # 检查行索引是否有效 + if row_index >= len(df): + print(f"错误: 行索引 {row_index} 超出范围(文件共有 {len(df)} 行)") + return + + # 打印指定行数据 + print(f"文件: {file_path}") + print(f"第 {row_index} 行数据:\n{'-'*30}") + print(df.iloc[row_index]) + + except FileNotFoundError: + print(f"错误: 文件 {file_path} 不存在") + except Exception as e: + print(f"读取失败: {str(e)}") + +# 示例用法 +if __name__ == "__main__": + file_path = "example.parquet" # 替换为你的Parquet文件路径 + read_and_print_parquet_row("/home/jgl20/LYT/work/data/data/chunk-000/episode_000000.parquet", row_index=0) # 打印第0行 diff --git a/collect_data/replay_data.py b/collect_data/replay_data.py new file mode 100644 index 0000000..6c880dc --- /dev/null +++ b/collect_data/replay_data.py @@ -0,0 +1,112 @@ +#coding=utf-8 +import os +import numpy as np +import cv2 +import h5py +import argparse +import rospy + +from cv_bridge import CvBridge +from std_msgs.msg import Header +from sensor_msgs.msg import Image, JointState +from geometry_msgs.msg import Twist +from 
lerobot.common.datasets.lerobot_dataset import LeRobotDataset + + + +def main(args): + rospy.init_node("replay_node") + bridge = CvBridge() + # img_left_publisher = rospy.Publisher(args.img_left_topic, Image, queue_size=10) + # img_right_publisher = rospy.Publisher(args.img_right_topic, Image, queue_size=10) + # img_front_publisher = rospy.Publisher(args.img_front_topic, Image, queue_size=10) + + # puppet_arm_left_publisher = rospy.Publisher(args.puppet_arm_left_topic, JointState, queue_size=10) + # puppet_arm_right_publisher = rospy.Publisher(args.puppet_arm_right_topic, JointState, queue_size=10) + + master_arm_left_publisher = rospy.Publisher(args.master_arm_left_topic, JointState, queue_size=10) + master_arm_right_publisher = rospy.Publisher(args.master_arm_right_topic, JointState, queue_size=10) + + # robot_base_publisher = rospy.Publisher(args.robot_base_topic, Twist, queue_size=10) + + + # dataset_dir = args.dataset_dir + # episode_idx = args.episode_idx + # task_name = args.task_name + # dataset_name = f'episode_{episode_idx}' + + dataset = LeRobotDataset(args.repo_id, root=args.root, episodes=[args.episode]) + actions = dataset.hf_dataset.select_columns("action") + velocitys = dataset.hf_dataset.select_columns("observation.velocity") + efforts = dataset.hf_dataset.select_columns("observation.effort") + + origin_left = [-0.0057,-0.031, -0.0122, -0.032, 0.0099, 0.0179, 0.2279] + origin_right = [ 0.0616, 0.0021, 0.0475, -0.1013, 0.1097, 0.0872, 0.2279] + + joint_state_msg = JointState() + joint_state_msg.header = Header() + joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', ''] # 设置关节名称 + twist_msg = Twist() + + rate = rospy.Rate(args.fps) + + # qposs, qvels, efforts, actions, base_actions, image_dicts = load_hdf5(os.path.join(dataset_dir, task_name), dataset_name) + + + last_action = [-0.00019073486328125, 0.00934600830078125, 0.01354217529296875, -0.01049041748046875, -0.00057220458984375, -0.00057220458984375, -0.00526118278503418, -0.00095367431640625, 0.00705718994140625, 0.01239776611328125, -0.00705718994140625, -0.00019073486328125, -0.00057220458984375, -0.009171326644718647] + last_velocity = [-0.010990142822265625, -0.010990142822265625, -0.03296661376953125, 0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.03296661376953125] + last_effort = [-0.021978378295898438, 0.2417583465576172, 4.320878982543945, 3.6527481079101562, -0.013187408447265625, -0.013187408447265625, 0.0, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.03296661376953125] + rate = rospy.Rate(50) + for idx in range(len(actions)): + action = actions[idx]['action'].detach().cpu().numpy() + velocity = velocitys[idx]['observation.velocity'].detach().cpu().numpy() + effort = efforts[idx]['observation.effort'].detach().cpu().numpy() + if(rospy.is_shutdown()): + break + + new_actions = np.linspace(last_action, action, 5) # 插值 + new_velocitys = np.linspace(last_velocity, velocity, 5) # 插值 + new_efforts = np.linspace(last_effort, effort, 5) # 插值 + last_action = action + last_velocity = velocity + last_effort = effort + for act in new_actions: + print(np.round(act[:7], 4)) + cur_timestamp = rospy.Time.now() # 设置时间戳 + joint_state_msg.header.stamp = cur_timestamp + + joint_state_msg.position = act[:7] + 
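+                # The 14-dim action/velocity/effort vectors are ordered
+                # [left arm (7 values), right arm (7 values)], following the motor lists in
+                # agilex.yaml, so [:7] addresses the left arm and [7:] the right arm.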
joint_state_msg.velocity = last_velocity[:7] + joint_state_msg.effort = last_effort[:7] + master_arm_left_publisher.publish(joint_state_msg) + + joint_state_msg.position = act[7:] + joint_state_msg.velocity = last_velocity[7:] + joint_state_msg.effort = last_effort[7:] + master_arm_right_publisher.publish(joint_state_msg) + + if(rospy.is_shutdown()): + break + rate.sleep() + + + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + # parser.add_argument('--master_arm_left_topic', action='store', type=str, help='master_arm_left_topic', + # default='/master/joint_left', required=False) + # parser.add_argument('--master_arm_right_topic', action='store', type=str, help='master_arm_right_topic', + # default='/master/joint_right', required=False) + + + args = parser.parse_args() + args.repo_id = "tangger/test" + args.root = "/home/ubuntu/LYT/aloha_lerobot/data1" + args.episode = 1 # replay episode + args.master_arm_left_topic = "/master/joint_left" + args.master_arm_right_topic = "/master/joint_right" + args.fps = 30 + + main(args) + # python collect_data.py --max_timesteps 500 --is_compress --episode_idx 0 \ No newline at end of file diff --git a/collect_data/rosrobot.py b/collect_data/rosrobot.py new file mode 100644 index 0000000..3a2de20 --- /dev/null +++ b/collect_data/rosrobot.py @@ -0,0 +1,372 @@ +import yaml +from collections import deque +import rospy +from cv_bridge import CvBridge +from typing import Dict, Any, Optional, List +from sensor_msgs.msg import Image, JointState +from nav_msgs.msg import Odometry +import argparse + + +class Robot: + def __init__(self, config_file: str, args: Optional[argparse.Namespace] = None): + """ + Robot base class that handles the common initialization logic. + Args: + config_file: path to the YAML config file + args: runtime arguments + """ + self._load_config(config_file) + self._merge_runtime_args(args) + self._init_components() + self._init_data_structures() + self.init_ros() + self.init_features() + self.warmup() + + def _load_config(self, config_file: str) -> None: + """Load the YAML config file.""" + with open(config_file, 'r') as f: + self.config = yaml.safe_load(f) + + def _merge_runtime_args(self, args: Optional[argparse.Namespace]) -> None: + """Merge runtime arguments into the config.""" + if args is None: + return + + runtime_params = { + 'frame_rate': getattr(args, 'fps', None), + 'max_timesteps': getattr(args, 'max_timesteps', None), + 'episode_idx': getattr(args, 'episode_idx', None), + 'use_depth_image': getattr(args, 'use_depth_image', None), + 'use_robot_base': getattr(args, 'use_base', None), + 'video': getattr(args, 'video', False), + 'control_type': getattr(args, 'control_type', False), + } + + for key, value in runtime_params.items(): + if value is not None: + self.config[key] = value + + def _init_components(self) -> None: + """Initialize the core components.""" + self.bridge = CvBridge() + self.subscribers = {} + self.publishers = {} + self._validate_config() + + def _validate_config(self) -> None: + """Validate that the config is complete.""" + required_sections = ['cameras', 'arm'] + for section in required_sections: + if section not in self.config: + raise ValueError(f"Missing required config section: {section}") + + def _init_data_structures(self) -> None: + """Template method that initializes the data structures.""" + # Camera data + self.cameras = self.config.get('cameras', {}) + self.sync_img_queues = {name: deque(maxlen=2000) for name in self.cameras} + + # Depth data + self.use_depth_image = self.config.get('use_depth_image', False) + if self.use_depth_image: + self.sync_depth_queues = { + name: deque(maxlen=2000) + for name, cam in self.cameras.items() + if 'depth_topic_name' in cam + } + + # Arm data + self.arms = 
self.config.get('arm', {}) + if self.config.get('control_type', '') != 'record': + # 如果不是录制模式,则仅初始化从机械臂数据队列 + self.sync_arm_queues = {name: deque(maxlen=2000) for name in self.arms if 'puppet' in name} + else: + self.sync_arm_queues = {name: deque(maxlen=2000) for name in self.arms} + + # 机器人基座数据 + self.use_robot_base = self.config.get('use_robot_base', False) + if self.use_robot_base: + self.sync_base_queue = deque(maxlen=2000) + + def init_ros(self) -> None: + """初始化ROS订阅的模板方法""" + rospy.init_node( + f"{self.config.get('ros_node_name', 'generic_robot_node')}", + anonymous=True + ) + + self._setup_camera_subscribers() + self._setup_arm_subscribers_publishers() + self._setup_base_subscriber() + self._log_ros_status() + + def init_features(self): + """ + 根据YAML配置自动生成features结构 + """ + self.features = {} + + # 初始化相机特征 + self._init_camera_features() + + # 初始化机械臂特征 + self._init_state_features() + + self._init_action_features() + + # 初始化基座特征(如果启用) + if self.use_robot_base: + self._init_base_features() + import pprint + pprint.pprint(self.features, indent=4) + + + def _init_camera_features(self): + """处理所有相机特征""" + for cam_name, cam_config in self.cameras.items(): + # 普通图像 + self.features[f"observation.images.{cam_name}"] = { + "dtype": "video" if self.config.get("video", False) else "image", + "shape": cam_config.get("rgb_shape", [480, 640, 3]), + "names": ["height", "width", "channel"], + # "video_info": { + # "video.fps": cam_config.get("fps", 30.0), + # "video.codec": cam_config.get("codec", "av1"), + # "video.pix_fmt": cam_config.get("pix_fmt", "yuv420p"), + # "video.is_depth_map": False, + # "has_audio": False + # } + } + + if self.config.get("use_depth_image", False): + self.features[f"observation.images.depth_{cam_name}"] = { + "dtype": "uint16", + "shape": (cam_config.get("width", 480), cam_config.get("height", 640), 1), + "names": ["height", "width"], + } + + + def _init_state_features(self): + state = self.config.get('state', {}) + # 状态特征 + self.features["observation.state"] = { + "dtype": "float32", + "shape": (len(state.get('motors', "")),), + "names": {"motors": state.get('motors', "")} + } + + if self.config.get('velocity'): + velocity = self.config.get('velocity', "") + self.features["observation.velocity"] = { + "dtype": "float32", + "shape": (len(velocity.get('motors', "")),), + "names": {"motors": velocity.get('motors', "")} + } + + if self.config.get('effort'): + effort = self.config.get('effort', "") + self.features["observation.effort"] = { + "dtype": "float32", + "shape": (len(effort.get('motors', "")),), + "names": {"motors": effort.get('motors', "")} + } + + + + def _init_action_features(self): + action = self.config.get('action', {}) + # 状态特征 + self.features["action"] = { + "dtype": "float32", + "shape": (len(action.get('motors', "")),), + "names": {"motors": action.get('motors', "")} + } + + def _init_base_features(self): + """处理基座特征""" + self.features["observation.base_vel"] = { + "dtype": "float32", + "shape": (2,), + "names": ["linear_x", "angular_z"] + } + + + def _setup_camera_subscribers(self) -> None: + """设置相机订阅者""" + for cam_name, cam_config in self.cameras.items(): + if 'img_topic_name' in cam_config: + self.subscribers[f"camera_{cam_name}"] = rospy.Subscriber( + cam_config['img_topic_name'], + Image, + self._make_camera_callback(cam_name, is_depth=False), + queue_size=1000, + tcp_nodelay=True + ) + + if self.use_depth_image and 'depth_topic_name' in cam_config: + self.subscribers[f"depth_{cam_name}"] = rospy.Subscriber( + cam_config['depth_topic_name'], + 
Image, + self._make_camera_callback(cam_name, is_depth=True), + queue_size=1000, + tcp_nodelay=True + ) + + def _setup_arm_subscribers_publishers(self) -> None: + """设置机械臂订阅者""" + # 当为record模式时,主从机械臂都需要订阅 + # 否则只订阅从机械臂,但向主机械臂发布 + if self.config.get('control_type', '') == 'record': + for arm_name, arm_config in self.arms.items(): + if 'topic_name' in arm_config: + self.subscribers[f"arm_{arm_name}"] = rospy.Subscriber( + arm_config['topic_name'], + JointState, + self._make_arm_callback(arm_name), + queue_size=1000, + tcp_nodelay=True + ) + else: + for arm_name, arm_config in self.arms.items(): + if 'puppet' in arm_name: + self.subscribers[f"arm_{arm_name}"] = rospy.Subscriber( + arm_config['topic_name'], + JointState, + self._make_arm_callback(arm_name), + queue_size=1000, + tcp_nodelay=True + ) + if 'master' in arm_name: + self.publishers[f"arm_{arm_name}"] = rospy.Publisher( + arm_config['topic_name'], + JointState, + queue_size=10 + ) + + def _setup_base_subscriber(self) -> None: + """设置基座订阅者""" + if self.use_robot_base and 'robot_base' in self.config: + self.subscribers['base'] = rospy.Subscriber( + self.config['robot_base']['topic_name'], + Odometry, + self.robot_base_callback, + queue_size=1000, + tcp_nodelay=True + ) + + def _log_ros_status(self) -> None: + """记录ROS状态""" + rospy.loginfo("\n=== ROS订阅状态 ===") + rospy.loginfo(f"已初始化节点: {rospy.get_name()}") + rospy.loginfo("活跃的订阅者:") + for topic, sub in self.subscribers.items(): + rospy.loginfo(f" - {topic}: {'活跃' if sub.impl else '未连接'}") + rospy.loginfo("=================") + + def _make_camera_callback(self, cam_name: str, is_depth: bool = False): + """生成相机回调函数工厂方法""" + def callback(msg): + try: + target_queue = ( + self.sync_depth_queues[cam_name] + if is_depth + else self.sync_img_queues[cam_name] + ) + if len(target_queue) >= 2000: + target_queue.popleft() + target_queue.append(msg) + except Exception as e: + rospy.logerr(f"Camera {cam_name} callback error: {str(e)}") + return callback + + def _make_arm_callback(self, arm_name: str): + """生成机械臂回调函数工厂方法""" + def callback(msg): + try: + if len(self.sync_arm_queues[arm_name]) >= 2000: + self.sync_arm_queues[arm_name].popleft() + self.sync_arm_queues[arm_name].append(msg) + except Exception as e: + rospy.logerr(f"Arm {arm_name} callback error: {str(e)}") + return callback + + def robot_base_callback(self, msg): + """基座回调默认实现""" + if len(self.sync_base_queue) >= 2000: + self.sync_base_queue.popleft() + self.sync_base_queue.append(msg) + + def warmup(self, timeout: float = 10.0) -> bool: + """Wait until all data queues have at least 20 messages. 
+ + Args: + timeout: Maximum time to wait in seconds before giving up + + Returns: + bool: True if warmup succeeded, False if timed out + """ + start_time = rospy.Time.now().to_sec() + rate = rospy.Rate(10) # Check at 10Hz + + rospy.loginfo("Starting warmup process...") + + while not rospy.is_shutdown(): + # Check if timeout has been reached + current_time = rospy.Time.now().to_sec() + if current_time - start_time > timeout: + rospy.logwarn("Warmup timed out before all queues were filled") + return False + + # Check all required queues + all_ready = True + + # Check camera image queues + for cam_name in self.cameras: + if len(self.sync_img_queues[cam_name]) < 50: + rospy.loginfo(f"Waiting for camera {cam_name} (current: {len(self.sync_img_queues[cam_name])}/50)") + all_ready = False + break + + # Check depth queues if enabled + if self.use_depth_image: + for cam_name in self.sync_depth_queues: + if len(self.sync_depth_queues[cam_name]) < 50: + rospy.loginfo(f"Waiting for depth camera {cam_name} (current: {len(self.sync_depth_queues[cam_name])}/50)") + all_ready = False + break + + # # Check arm queues + # for arm_name in self.arms: + # if len(self.sync_arm_queues[arm_name]) < 20: + # rospy.loginfo(f"Waiting for arm {arm_name} (current: {len(self.sync_arm_queues[arm_name])}/20)") + # all_ready = False + # break + + # Check base queue if enabled + if self.use_robot_base: + if len(self.sync_base_queue) < 20: + rospy.loginfo(f"Waiting for base (current: {len(self.sync_base_queue)}/20)") + all_ready = False + + # If all queues are ready, return success + if all_ready: + rospy.loginfo("Warmup completed successfully") + return True + + rate.sleep() + + return False + + + + + + def get_frame(self) -> Optional[Dict[str, Any]]: + """获取同步帧数据的模板方法""" + raise NotImplementedError("Subclasses must implement get_frame()") + + def process(self) -> tuple: + """主处理循环的模板方法""" + raise NotImplementedError("Subclasses must implement process()") diff --git a/collect_data/rosrobot_factory.py b/collect_data/rosrobot_factory.py new file mode 100644 index 0000000..6ca6bec --- /dev/null +++ b/collect_data/rosrobot_factory.py @@ -0,0 +1,26 @@ +import yaml +import argparse +from typing import Dict, List, Any, Optional +from rosrobot import Robot +from agilex_robot import AgilexRobot + + +class RobotFactory: + @staticmethod + def create(config_file: str, args: Optional[argparse.Namespace] = None) -> Robot: + """ + 根据配置文件自动创建合适的机器人实例 + Args: + config_file: 配置文件路径 + args: 运行时参数 + """ + with open(config_file, 'r') as f: + config = yaml.safe_load(f) + + robot_type = config.get('robot_type', 'agilex') + + if robot_type == 'agilex': + return AgilexRobot(config_file, args) + # 可扩展其他机器人类型 + else: + raise ValueError(f"Unsupported robot type: {robot_type}") diff --git a/collect_data/test.py b/collect_data/test.py new file mode 100644 index 0000000..8eb8748 --- /dev/null +++ b/collect_data/test.py @@ -0,0 +1,70 @@ +from lerobot.common.policies.act.modeling_act import ACTPolicy +from lerobot.common.robot_devices.utils import busy_wait +import time +import argparse +from agilex_robot import AgilexRobot +import torch + +def get_arguments(): + parser = argparse.ArgumentParser() + args = parser.parse_args() + args.fps = 30 + args.resume = False + args.repo_id = "tangger/test" + args.root = "./data2" + args.num_image_writer_processes = 0 + args.num_image_writer_threads_per_camera = 4 + args.video = True + args.num_episodes = 50 + args.episode_time_s = 30000 + args.play_sounds = False + args.display_cameras = True + args.single_task = 
"test test" + args.use_depth_image = False + args.use_base = False + args.push_to_hub = False + args.policy= None + args.teleoprate = False + return args + + +cfg = get_arguments() +robot = AgilexRobot(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg) +inference_time_s = 360 +fps = 30 +device = "cuda" # TODO: On Mac, use "mps" or "cpu" + +ckpt_path = "/home/ubuntu/LYT/lerobot/outputs/train/act_move_tube_on_scale/checkpoints/last/pretrained_model" +policy = ACTPolicy.from_pretrained(ckpt_path) +policy.to(device) + +for _ in range(inference_time_s * fps): + start_time = time.perf_counter() + + # Read the follower state and access the frames from the cameras + observation = robot.capture_observation() + if observation is None: + print("Observation is None, skipping...") + continue + + # Convert to pytorch format: channel first and float32 in [0,1] + # with batch dimension + for name in observation: + if "image" in name: + observation[name] = observation[name].type(torch.float32) / 255 + observation[name] = observation[name].permute(2, 0, 1).contiguous() + observation[name] = observation[name].unsqueeze(0) + observation[name] = observation[name].to(device) + + # Compute the next action with the policy + # based on the current observation + action = policy.select_action(observation) + # Remove batch dimension + action = action.squeeze(0) + # Move to cpu, if not already the case + action = action.to("cpu") + # Order the robot to move + robot.send_action(action) + + dt_s = time.perf_counter() - start_time + busy_wait(1 / fps - dt_s) \ No newline at end of file diff --git a/test.pt b/test.pt new file mode 100644 index 0000000..15d58ea Binary files /dev/null and b/test.pt differ