modify code structure

This commit is contained in:
2025-04-07 19:45:34 +08:00
parent 91c2b7b0cb
commit d843a990a3
25 changed files with 2135 additions and 333 deletions

View File

@@ -19,6 +19,12 @@ cameras:
rgb_shape: [480, 640, 3]
width: 480
height: 640
cam_high:
img_topic_name: /camera/color/image_raw
depth_topic_name: /camera/depth/image_rect_raw
rgb_shape: [480, 640, 3]
width: 480
height: 640
arm:
master_left:

View File

@@ -1,456 +0,0 @@
import yaml
import cv2
import numpy as np
import collections
import dm_env
import argparse
from typing import Dict, List, Any, Optional
from collections import deque
import rospy
from cv_bridge import CvBridge
from std_msgs.msg import Header
from sensor_msgs.msg import Image, JointState
from nav_msgs.msg import Odometry
from rosrobot import Robot
import torch
import time
class AgilexRobot(Robot):
def get_frame(self) -> Optional[Dict[str, Any]]:
    """Return one time-synchronized frame across all streams, or None.

    The returned dict holds 'images', 'arms' and 'timestamp', plus
    'depths' / 'base' when those streams are enabled.
    """
    # Bail out early while any mandatory queue is still empty.
    if any(len(q) == 0 for q in self.sync_img_queues.values()):
        print("camera has not get image data")
        return None
    if self.use_depth_image and any(
        len(q) == 0 for q in self.sync_depth_queues.values()
    ):
        return None
    if (
        len(self.sync_arm_queues['puppet_left']) == 0
        or len(self.sync_arm_queues['puppet_right']) == 0
    ):
        print("can not get data from puppet topic")
        return None

    # Reference time = the oldest of the newest stamps over all streams.
    stamps = [
        q[-1].header.stamp.to_sec()
        for q in list(self.sync_img_queues.values())
        + list(self.sync_arm_queues.values())
    ]
    if self.use_depth_image:
        stamps.extend(
            q[-1].header.stamp.to_sec() for q in self.sync_depth_queues.values()
        )
    if self.use_robot_base and len(self.sync_base_queue) > 0:
        stamps.append(self.sync_base_queue[-1].header.stamp.to_sec())
    min_time = min(stamps)

    # Every stream must have reached the reference time, otherwise give up.
    for q in list(self.sync_img_queues.values()) + list(self.sync_arm_queues.values()):
        if q[-1].header.stamp.to_sec() < min_time:
            return None
    if self.use_depth_image:
        for q in self.sync_depth_queues.values():
            if q[-1].header.stamp.to_sec() < min_time:
                return None
    if self.use_robot_base and len(self.sync_base_queue) > 0:
        if self.sync_base_queue[-1].header.stamp.to_sec() < min_time:
            return None

    frame_data = {'images': {}, 'arms': {}, 'timestamp': min_time}

    # Drop stale messages, then take the first message at/after min_time.
    for cam_name, q in self.sync_img_queues.items():
        while q[0].header.stamp.to_sec() < min_time:
            q.popleft()
        frame_data['images'][cam_name] = self.bridge.imgmsg_to_cv2(
            q.popleft(), 'passthrough'
        )

    if self.use_depth_image:
        frame_data['depths'] = {}
        for cam_name, q in self.sync_depth_queues.items():
            while q[0].header.stamp.to_sec() < min_time:
                q.popleft()
            raw = self.bridge.imgmsg_to_cv2(q.popleft(), 'passthrough')
            # Keep the original 40 px top/bottom padding of the pipeline.
            frame_data['depths'][cam_name] = cv2.copyMakeBorder(
                raw, 40, 40, 0, 0, cv2.BORDER_CONSTANT, value=0
            )

    for arm_name, q in self.sync_arm_queues.items():
        while q[0].header.stamp.to_sec() < min_time:
            q.popleft()
        frame_data['arms'][arm_name] = q.popleft()

    if self.use_robot_base and len(self.sync_base_queue) > 0:
        while self.sync_base_queue[0].header.stamp.to_sec() < min_time:
            self.sync_base_queue.popleft()
        frame_data['base'] = self.sync_base_queue.popleft()

    return frame_data
def teleop_step(self) -> Optional[tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]]:
    """Return one synchronized (observation, action) pair.

    Puppet-arm joint states become the observation; master-arm positions
    become the action. Returns (None, None) when the streams cannot be
    aligned yet.
    """
    # Every stream must already have at least one message.
    if any(len(q) == 0 for q in self.sync_img_queues.values()):
        return None, None
    if self.use_depth_image and any(
        len(q) == 0 for q in self.sync_depth_queues.values()
    ):
        return None, None
    if any(len(q) == 0 for q in self.sync_arm_queues.values()):
        return None, None

    # Reference time = the oldest of the newest stamps over all streams.
    stamps = [
        q[-1].header.stamp.to_sec()
        for q in list(self.sync_img_queues.values())
        + list(self.sync_arm_queues.values())
    ]
    if self.use_depth_image:
        stamps.extend(
            q[-1].header.stamp.to_sec() for q in self.sync_depth_queues.values()
        )
    if self.use_robot_base and len(self.sync_base_queue) > 0:
        stamps.append(self.sync_base_queue[-1].header.stamp.to_sec())
    min_time = min(stamps)

    # Abort when any stream has not reached the reference time yet.
    for q in list(self.sync_img_queues.values()) + list(self.sync_arm_queues.values()):
        if q[-1].header.stamp.to_sec() < min_time:
            return None, None
    if self.use_depth_image:
        for q in self.sync_depth_queues.values():
            if q[-1].header.stamp.to_sec() < min_time:
                return None, None
    if self.use_robot_base and len(self.sync_base_queue) > 0:
        if self.sync_base_queue[-1].header.stamp.to_sec() < min_time:
            return None, None

    obs_dict = {}
    action_dict = {}

    # Camera images.
    for cam_name, q in self.sync_img_queues.items():
        while q[0].header.stamp.to_sec() < min_time:
            q.popleft()
        obs_dict[f"observation.images.{cam_name}"] = torch.from_numpy(
            self.bridge.imgmsg_to_cv2(q.popleft(), 'passthrough')
        )

    # Depth maps with the pipeline's 40 px top/bottom padding.
    if self.use_depth_image:
        for cam_name, q in self.sync_depth_queues.items():
            while q[0].header.stamp.to_sec() < min_time:
                q.popleft()
            raw = self.bridge.imgmsg_to_cv2(q.popleft(), 'passthrough')
            padded = cv2.copyMakeBorder(raw, 40, 40, 0, 0, cv2.BORDER_CONSTANT, value=0)
            obs_dict[f"observation.images.depth_{cam_name}"] = (
                torch.from_numpy(padded).unsqueeze(-1)
            )

    # Puppet arms feed the observation; master arms feed the action.
    states, velocities, efforts, actions = [], [], [], []
    for arm_name, q in self.sync_arm_queues.items():
        while q[0].header.stamp.to_sec() < min_time:
            q.popleft()
        arm_data = q.popleft()
        if arm_name.startswith('puppet'):
            states.append(np.array(arm_data.position, dtype=np.float32))
            velocities.append(np.array(arm_data.velocity, dtype=np.float32))
            efforts.append(np.array(arm_data.effort, dtype=np.float32))
        if arm_name.startswith('master'):
            actions.append(np.array(arm_data.position, dtype=np.float32))
    if states:
        obs_dict["observation.state"] = torch.tensor(np.concatenate(states).reshape(-1))
    if velocities:
        obs_dict["observation.velocity"] = torch.tensor(np.concatenate(velocities).reshape(-1))
    if efforts:
        obs_dict["observation.effort"] = torch.tensor(np.concatenate(efforts).reshape(-1))
    if actions:
        action_dict["action"] = torch.tensor(np.concatenate(actions).reshape(-1))

    # Mobile-base velocity (linear x, angular z) when enabled.
    if self.use_robot_base and len(self.sync_base_queue) > 0:
        while self.sync_base_queue[0].header.stamp.to_sec() < min_time:
            self.sync_base_queue.popleft()
        base_data = self.sync_base_queue.popleft()
        obs_dict["observation.base_vel"] = torch.tensor(
            [base_data.twist.twist.linear.x, base_data.twist.twist.angular.z],
            dtype=torch.float32,
        )

    return obs_dict, action_dict
def capture_observation(self):
    """Build an observation dict from the latest synchronized frame.

    Returns:
        dict mapping feature names to tensors, or None when no
        synchronized frame is currently available.
    """
    obs_dict = {}
    frame_data = self.get_frame()
    if frame_data is None:
        return None

    # Puppet-arm joint states become the robot observation.
    positions, velocities, efforts = [], [], []
    for arm_name, joint_state in frame_data['arms'].items():
        if not arm_name.startswith('puppet'):
            continue
        before_read_t = time.perf_counter()
        positions.append(torch.from_numpy(np.array(joint_state.position, dtype=np.float32)))
        velocities.append(torch.from_numpy(np.array(joint_state.velocity, dtype=np.float32)))
        efforts.append(torch.from_numpy(np.array(joint_state.effort, dtype=np.float32)))
        print(f"read_arm_{arm_name}_pos_dt_s is", time.perf_counter() - before_read_t)
    if positions:
        obs_dict["observation.state"] = torch.cat(positions)
    if velocities:
        obs_dict["observation.velocity"] = torch.cat(velocities)
    if efforts:
        obs_dict["observation.effort"] = torch.cat(efforts)

    # Camera images.
    for cam_name, img in frame_data['images'].items():
        before_camread_t = time.perf_counter()
        obs_dict[f"observation.images.{cam_name}"] = torch.from_numpy(img)
        print(f"read_camera_{cam_name}_dt_s is", time.perf_counter() - before_camread_t)

    # Depth maps (channel dimension appended) when enabled.
    if self.use_depth_image and 'depths' in frame_data:
        for cam_name, depth_img in frame_data['depths'].items():
            before_depthread_t = time.perf_counter()
            obs_dict[f"observation.images.depth_{cam_name}"] = (
                torch.from_numpy(depth_img).unsqueeze(-1)
            )
            print(f"read_depth_{cam_name}_dt_s is", time.perf_counter() - before_depthread_t)

    # Mobile-base velocity (linear x, angular z) when enabled.
    if self.use_robot_base and 'base' in frame_data:
        base_data = frame_data['base']
        obs_dict["observation.base_vel"] = torch.tensor(
            [base_data.twist.twist.linear.x, base_data.twist.twist.angular.z],
            dtype=torch.float32,
        )
    return obs_dict
def send_action(self, action: torch.Tensor) -> torch.Tensor:
    """Publish goal joint positions for every master arm.

    Args:
        action: concatenated goal positions for all master arms, in the
            order the arms appear in ``self.arms``.

    Returns:
        The action actually sent (possibly clipped by the optional
        ``max_relative_target`` safety limit) as a flat tensor; an empty
        tensor when no master arm is configured.
    """
    # Placeholder velocity/effort values; they are sliced per arm and sent
    # verbatim alongside each position command.
    last_velocity = [-0.010990142822265625, -0.010990142822265625, -0.03296661376953125, 0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.03296661376953125]
    last_effort = [-0.021978378295898438, 0.2417583465576172, 4.320878982543945, 3.6527481079101562, -0.013187408447265625, -0.013187408447265625, 0.0, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.03296661376953125]
    # Work with a plain numpy array from here on.
    if isinstance(action, torch.Tensor):
        action = action.detach().cpu().numpy()
    from_idx = 0
    to_idx = 0
    action_sent = []
    for arm_name, arm_config in self.arms.items():
        # Commands are only published to master arms.
        if "master" not in arm_name:
            continue
        # Slice this arm's portion out of the concatenated action.
        num_joints = len(arm_config.get('motors', []))
        to_idx += num_joints
        arm_action = action[from_idx:to_idx]
        arm_velocity = last_velocity[from_idx:to_idx]
        arm_effort = last_effort[from_idx:to_idx]
        from_idx = to_idx
        # Optional safety clamp: limit per-step joint displacement relative
        # to the most recent observed position (skipped when no state yet).
        if 'max_relative_target' in self.config:
            if len(self.sync_arm_queues[arm_name]) > 0:
                current_pos = np.array(self.sync_arm_queues[arm_name][-1].position)
                max_delta = self.config['max_relative_target']
                arm_action = np.clip(
                    arm_action, current_pos - max_delta, current_pos + max_delta
                )
        action_sent.append(arm_action)
        # Build and publish the JointState command.
        joint_state = JointState()
        joint_state.header = Header()
        joint_state.header.stamp = rospy.Time.now()
        # NOTE(review): 7 joint names with a trailing empty entry — looks
        # like the gripper joint name is missing; confirm against the driver.
        joint_state.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', '']
        joint_state.position = arm_action.tolist()
        joint_state.velocity = arm_velocity
        joint_state.effort = arm_effort
        self.publishers[f"arm_{arm_name}"].publish(joint_state)
    return torch.from_numpy(np.concatenate(action_sent)) if action_sent else torch.tensor([])
# def _init_action_publishers(self) -> None:
# """Initialize ROS publishers for puppet arms"""
# self.puppet_arm_publishers = {}
# # rospy.init_node("replay_node")
# for arm_name, arm_config in self.arms.items():
# if not "puppet" in arm_name:
# # if not "master" in arm_name:
# continue
# if 'topic_name' not in arm_config:
# rospy.logwarn(f"No puppet topic defined for arm {arm_name}")
# continue
# self.puppet_arm_publishers[arm_name] = rospy.Publisher(
# arm_config['topic_name'],
# JointState,
# queue_size=10
# )
# # Wait for publisher to connect
# rospy.sleep(0.1)
# rospy.loginfo("Initialized puppet arm publishers")
# def get_arguments() -> argparse.Namespace:
# """获取运行时参数"""
# parser = argparse.ArgumentParser()
# parser.add_argument('--fps', type=int, help='Frame rate', default=30)
# parser.add_argument('--max_timesteps', type=int, help='Max timesteps', default=500)
# parser.add_argument('--episode_idx', type=int, help='Episode index', default=0)
# parser.add_argument('--use_depth', action='store_true', help='Use depth images')
# parser.add_argument('--use_base', action='store_true', help='Use robot base')
# return parser.parse_args()
# if __name__ == "__main__":
# # 示例用法
# import json
# args = get_arguments()
# robot = AgilexRobot(config_file="/home/jgl20/LYT/work/collect_data_lerobot/1.yaml", args=args)
# print(json.dumps(robot.features, indent=4))
# robot.warmup_record()
# count = 0
# print_flag = True
# rate = rospy.Rate(args.fps)
# while (count < args.max_timesteps + 1) and not rospy.is_shutdown():
# a, b = robot.teleop_step()
# if a is None or b is None:
# if print_flag:
# print("syn fail\n")
# print_flag = False
# rate.sleep()
# continue
# else:
# print(a)
# timesteps, actions = robot.process()
# print(timesteps)
# Stray module-level debug print() removed: it emitted a blank line on import.

View File

@@ -408,7 +408,7 @@ def get_arguments():
args.fps = 30
args.resume = False
args.repo_id = "move_the_bottle_from_the_right_to_the_scale_right"
args.root = "/home/ubuntu/LYT/aloha_lerobot/data4"
args.root = "./data5"
args.episode = 0 # replay episode
args.num_image_writer_processes = 0
args.num_image_writer_threads_per_camera = 4
@@ -429,51 +429,26 @@ def get_arguments():
# @parser.wrap()
# def control_robot(cfg: ControlPipelineConfig):
# init_logging()
# logging.info(pformat(asdict(cfg)))
# # robot = make_robot_from_config(cfg.robot)
# from agilex_robot import AgilexRobot
# robot = AgilexRobot(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg)
# if isinstance(cfg.control, RecordControlConfig):
# print(cfg.control)
# record(robot, cfg.control)
# elif isinstance(cfg.control, ReplayControlConfig):
# replay(robot, cfg.control)
# # if robot.is_connected:
# # # Disconnect manually to avoid a "Core dump" during process
# # # termination due to camera threads not properly exiting.
# # robot.disconnect()
# @parser.wrap()
def control_robot(cfg):
    """Create a robot via the factory and dispatch to record or replay.

    Args:
        cfg: runtime config namespace; must provide ``control_type``
            ("record" or "replay") plus whatever record()/replay() need.
    """
    from rosrobot_factory import RobotFactory

    # The factory selects the concrete Robot subclass from the YAML config.
    # Fixed: the original also instantiated AgilexRobot directly and then
    # immediately discarded it, performing a wasted second ROS-node init.
    robot = RobotFactory.create(
        config_file="/home/ubuntu/LYT/lerobot_aloha/collect_data/agilex.yaml",
        args=cfg,
    )
    if cfg.control_type == "record":
        record(robot, cfg)
    elif cfg.control_type == "replay":
        replay(robot, cfg)
if __name__ == "__main__":
    # Parse runtime arguments and run the selected control mode.
    control_robot(get_arguments())
# control_robot()
# cfg = get_arguments()
# from agilex_robot import AgilexRobot
# robot = AgilexRobot(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg)
# 使用工厂模式创建机器人实例
# robot = RobotFactory.create(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg)
# print(robot.features.items())
# print([key for key, ft in robot.features.items() if ft["dtype"] == "video"])
# record(robot, cfg)

View File

@@ -1,372 +0,0 @@
import yaml
from collections import deque
import rospy
from cv_bridge import CvBridge
from typing import Dict, Any, Optional, List
from sensor_msgs.msg import Image, JointState
from nav_msgs.msg import Odometry
import argparse
class Robot:
    """Base ROS robot: loads a YAML config, buffers messages, wires ROS I/O.

    Subclasses must implement :meth:`get_frame` and :meth:`process`.
    """

    # Capacity shared by every message buffer.
    QUEUE_MAXLEN = 2000

    def __init__(self, config_file: str, args: Optional[argparse.Namespace] = None):
        """Initialize from a YAML config plus optional runtime overrides.

        Args:
            config_file: path to the YAML configuration file.
            args: optional runtime arguments merged over the file config.
        """
        self._load_config(config_file)
        self._merge_runtime_args(args)
        self._init_components()
        self._init_data_structures()
        self.init_ros()
        self.init_features()
        self.warmup()

    def _load_config(self, config_file: str) -> None:
        """Load the YAML configuration file into ``self.config``."""
        with open(config_file, 'r') as f:
            self.config = yaml.safe_load(f)

    def _merge_runtime_args(self, args: Optional[argparse.Namespace]) -> None:
        """Overlay runtime arguments onto the file config.

        ``None`` values are skipped; note that 'video' and 'control_type'
        default to ``False`` rather than ``None``, so they are always set.
        """
        if args is None:
            return
        runtime_params = {
            'frame_rate': getattr(args, 'fps', None),
            'max_timesteps': getattr(args, 'max_timesteps', None),
            'episode_idx': getattr(args, 'episode_idx', None),
            'use_depth_image': getattr(args, 'use_depth_image', None),
            'use_robot_base': getattr(args, 'use_base', None),
            'video': getattr(args, 'video', False),
            'control_type': getattr(args, 'control_type', False),
        }
        for key, value in runtime_params.items():
            if value is not None:
                self.config[key] = value

    def _init_components(self) -> None:
        """Create core helpers (cv_bridge, subscriber/publisher registries)."""
        self.bridge = CvBridge()
        self.subscribers = {}
        self.publishers = {}
        self._validate_config()

    def _validate_config(self) -> None:
        """Raise ValueError when a required config section is missing."""
        for section in ('cameras', 'arm'):
            if section not in self.config:
                raise ValueError(f"Missing required config section: {section}")

    def _init_data_structures(self) -> None:
        """Create the per-stream message queues from the config."""
        self.cameras = self.config.get('cameras', {})
        self.sync_img_queues = {
            name: deque(maxlen=self.QUEUE_MAXLEN) for name in self.cameras
        }
        self.use_depth_image = self.config.get('use_depth_image', False)
        if self.use_depth_image:
            # Only cameras that actually publish a depth topic get a queue.
            self.sync_depth_queues = {
                name: deque(maxlen=self.QUEUE_MAXLEN)
                for name, cam in self.cameras.items()
                if 'depth_topic_name' in cam
            }
        self.arms = self.config.get('arm', {})
        if self.config.get('control_type', '') != 'record':
            # Outside record mode only the puppet (follower) arms are buffered.
            self.sync_arm_queues = {
                name: deque(maxlen=self.QUEUE_MAXLEN)
                for name in self.arms if 'puppet' in name
            }
        else:
            self.sync_arm_queues = {
                name: deque(maxlen=self.QUEUE_MAXLEN) for name in self.arms
            }
        self.use_robot_base = self.config.get('use_robot_base', False)
        if self.use_robot_base:
            self.sync_base_queue = deque(maxlen=self.QUEUE_MAXLEN)

    def init_ros(self) -> None:
        """Start the ROS node and wire up all subscribers/publishers."""
        rospy.init_node(
            f"{self.config.get('ros_node_name', 'generic_robot_node')}",
            anonymous=True
        )
        self._setup_camera_subscribers()
        self._setup_arm_subscribers_publishers()
        self._setup_base_subscriber()
        self._log_ros_status()

    def init_features(self):
        """Build the dataset feature schema from the YAML config."""
        self.features = {}
        self._init_camera_features()
        self._init_state_features()
        self._init_action_features()
        if self.use_robot_base:
            self._init_base_features()
        # Debug dump of the generated schema.
        import pprint
        pprint.pprint(self.features, indent=4)

    def _init_camera_features(self):
        """Register RGB (and optionally depth) features for every camera."""
        for cam_name, cam_config in self.cameras.items():
            self.features[f"observation.images.{cam_name}"] = {
                "dtype": "video" if self.config.get("video", False) else "image",
                "shape": cam_config.get("rgb_shape", [480, 640, 3]),
                "names": ["height", "width", "channel"],
            }
            if self.config.get("use_depth_image", False):
                # NOTE(review): width is used as the first dimension here; the
                # sample configs store width=480/height=640, which matches the
                # [480, 640, 3] rgb_shape — confirm the intended axis order.
                self.features[f"observation.images.depth_{cam_name}"] = {
                    "dtype": "uint16",
                    "shape": (cam_config.get("width", 480), cam_config.get("height", 640), 1),
                    # Fixed: one name per shape dimension (was 2 names for 3 dims).
                    "names": ["height", "width", "channel"],
                }

    def _init_state_features(self):
        """Register joint-state (plus optional velocity/effort) features."""
        # Fixed: use [] (not "") as the missing-motors sentinel so 'names'
        # always carries a list; len() behaves identically.
        state = self.config.get('state', {})
        self.features["observation.state"] = {
            "dtype": "float32",
            "shape": (len(state.get('motors', [])),),
            "names": {"motors": state.get('motors', [])}
        }
        if self.config.get('velocity'):
            velocity = self.config.get('velocity', {})
            self.features["observation.velocity"] = {
                "dtype": "float32",
                "shape": (len(velocity.get('motors', [])),),
                "names": {"motors": velocity.get('motors', [])}
            }
        if self.config.get('effort'):
            effort = self.config.get('effort', {})
            self.features["observation.effort"] = {
                "dtype": "float32",
                "shape": (len(effort.get('motors', [])),),
                "names": {"motors": effort.get('motors', [])}
            }

    def _init_action_features(self):
        """Register the action feature (master-arm goal positions)."""
        action = self.config.get('action', {})
        self.features["action"] = {
            "dtype": "float32",
            "shape": (len(action.get('motors', [])),),
            "names": {"motors": action.get('motors', [])}
        }

    def _init_base_features(self):
        """Register the mobile-base velocity feature."""
        self.features["observation.base_vel"] = {
            "dtype": "float32",
            "shape": (2,),
            "names": ["linear_x", "angular_z"]
        }

    def _setup_camera_subscribers(self) -> None:
        """Subscribe to every camera's RGB (and optional depth) topic."""
        for cam_name, cam_config in self.cameras.items():
            if 'img_topic_name' in cam_config:
                self.subscribers[f"camera_{cam_name}"] = rospy.Subscriber(
                    cam_config['img_topic_name'],
                    Image,
                    self._make_camera_callback(cam_name, is_depth=False),
                    queue_size=1000,
                    tcp_nodelay=True
                )
            if self.use_depth_image and 'depth_topic_name' in cam_config:
                self.subscribers[f"depth_{cam_name}"] = rospy.Subscriber(
                    cam_config['depth_topic_name'],
                    Image,
                    self._make_camera_callback(cam_name, is_depth=True),
                    queue_size=1000,
                    tcp_nodelay=True
                )

    def _setup_arm_subscribers_publishers(self) -> None:
        """Wire arm topics.

        Record mode subscribes to every arm; otherwise we subscribe to the
        puppet arms and publish commands to the master arms.
        """
        if self.config.get('control_type', '') == 'record':
            for arm_name, arm_config in self.arms.items():
                if 'topic_name' in arm_config:
                    self.subscribers[f"arm_{arm_name}"] = rospy.Subscriber(
                        arm_config['topic_name'],
                        JointState,
                        self._make_arm_callback(arm_name),
                        queue_size=1000,
                        tcp_nodelay=True
                    )
        else:
            for arm_name, arm_config in self.arms.items():
                # Fixed: guard against arms without a 'topic_name' entry (the
                # record branch already checked; this branch raised KeyError).
                if 'topic_name' not in arm_config:
                    rospy.logwarn(f"No topic defined for arm {arm_name}")
                    continue
                if 'puppet' in arm_name:
                    self.subscribers[f"arm_{arm_name}"] = rospy.Subscriber(
                        arm_config['topic_name'],
                        JointState,
                        self._make_arm_callback(arm_name),
                        queue_size=1000,
                        tcp_nodelay=True
                    )
                if 'master' in arm_name:
                    self.publishers[f"arm_{arm_name}"] = rospy.Publisher(
                        arm_config['topic_name'],
                        JointState,
                        queue_size=10
                    )

    def _setup_base_subscriber(self) -> None:
        """Subscribe to the mobile-base odometry topic when enabled."""
        if self.use_robot_base and 'robot_base' in self.config:
            self.subscribers['base'] = rospy.Subscriber(
                self.config['robot_base']['topic_name'],
                Odometry,
                self.robot_base_callback,
                queue_size=1000,
                tcp_nodelay=True
            )

    def _log_ros_status(self) -> None:
        """Log node name and subscriber status."""
        rospy.loginfo("\n=== ROS订阅状态 ===")
        rospy.loginfo(f"已初始化节点: {rospy.get_name()}")
        rospy.loginfo("活跃的订阅者:")
        # NOTE(review): Subscriber.impl is a private rospy attribute — verify
        # this connectivity check against the rospy version in use.
        for topic, sub in self.subscribers.items():
            rospy.loginfo(f" - {topic}: {'活跃' if sub.impl else '未连接'}")
        rospy.loginfo("=================")

    def _make_camera_callback(self, cam_name: str, is_depth: bool = False):
        """Return a ROS callback that buffers image/depth messages."""
        def callback(msg):
            try:
                target_queue = (
                    self.sync_depth_queues[cam_name]
                    if is_depth
                    else self.sync_img_queues[cam_name]
                )
                # deque(maxlen=...) evicts the oldest entry automatically,
                # so no manual popleft is needed (removed redundant check).
                target_queue.append(msg)
            except Exception as e:
                rospy.logerr(f"Camera {cam_name} callback error: {str(e)}")
        return callback

    def _make_arm_callback(self, arm_name: str):
        """Return a ROS callback that buffers arm joint-state messages."""
        def callback(msg):
            try:
                # deque(maxlen=...) evicts the oldest entry automatically.
                self.sync_arm_queues[arm_name].append(msg)
            except Exception as e:
                rospy.logerr(f"Arm {arm_name} callback error: {str(e)}")
        return callback

    def robot_base_callback(self, msg):
        """Default base-odometry callback: buffer the message.

        deque(maxlen=...) evicts the oldest entry automatically.
        """
        self.sync_base_queue.append(msg)

    def warmup(self, timeout: float = 10.0) -> bool:
        """Block until every stream has buffered enough messages.

        Camera and depth queues must hold at least 50 messages and, when
        enabled, the base queue at least 20 (the original docstring said 20
        everywhere, which did not match the code).

        Args:
            timeout: maximum time to wait in seconds before giving up.

        Returns:
            bool: True if warmup succeeded, False if timed out.
        """
        start_time = rospy.Time.now().to_sec()
        rate = rospy.Rate(10)  # poll at 10 Hz
        rospy.loginfo("Starting warmup process...")
        while not rospy.is_shutdown():
            if rospy.Time.now().to_sec() - start_time > timeout:
                rospy.logwarn("Warmup timed out before all queues were filled")
                return False
            all_ready = True
            for cam_name in self.cameras:
                if len(self.sync_img_queues[cam_name]) < 50:
                    rospy.loginfo(f"Waiting for camera {cam_name} (current: {len(self.sync_img_queues[cam_name])}/50)")
                    all_ready = False
                    break
            if self.use_depth_image:
                for cam_name in self.sync_depth_queues:
                    if len(self.sync_depth_queues[cam_name]) < 50:
                        rospy.loginfo(f"Waiting for depth camera {cam_name} (current: {len(self.sync_depth_queues[cam_name])}/50)")
                        all_ready = False
                        break
            if self.use_robot_base:
                if len(self.sync_base_queue) < 20:
                    rospy.loginfo(f"Waiting for base (current: {len(self.sync_base_queue)}/20)")
                    all_ready = False
            if all_ready:
                rospy.loginfo("Warmup completed successfully")
                return True
            rate.sleep()
        return False

    def get_frame(self) -> Optional[Dict[str, Any]]:
        """Template method: return one synchronized frame (see subclasses)."""
        raise NotImplementedError("Subclasses must implement get_frame()")

    def process(self) -> tuple:
        """Template method: main processing loop (see subclasses)."""
        raise NotImplementedError("Subclasses must implement process()")

View File

@@ -1,26 +0,0 @@
import yaml
import argparse
from typing import Dict, List, Any, Optional
from rosrobot import Robot
from agilex_robot import AgilexRobot
class RobotFactory:
    """Builds the concrete Robot subclass selected by the YAML config."""

    @staticmethod
    def create(config_file: str, args: Optional[argparse.Namespace] = None) -> Robot:
        """Instantiate the robot type named by ``robot_type`` in the config.

        Args:
            config_file: path to the YAML configuration file.
            args: optional runtime arguments forwarded to the robot.

        Raises:
            ValueError: when ``robot_type`` names an unsupported robot.
        """
        with open(config_file, 'r') as fh:
            cfg = yaml.safe_load(fh)
        robot_type = cfg.get('robot_type', 'agilex')
        if robot_type == 'agilex':
            return AgilexRobot(config_file, args)
        # Extend here with additional robot types.
        raise ValueError(f"Unsupported robot type: {robot_type}")