init repo
This commit is contained in:
372
collect_data/rosrobot.py
Normal file
372
collect_data/rosrobot.py
Normal file
@@ -0,0 +1,372 @@
|
||||
import yaml
|
||||
from collections import deque
|
||||
import rospy
|
||||
from cv_bridge import CvBridge
|
||||
from typing import Dict, Any, Optional, List
|
||||
from sensor_msgs.msg import Image, JointState
|
||||
from nav_msgs.msg import Odometry
|
||||
import argparse
|
||||
|
||||
|
||||
class Robot:
|
||||
def __init__(self, config_file: str, args: Optional[argparse.Namespace] = None):
|
||||
"""
|
||||
机器人基类,处理通用初始化逻辑
|
||||
Args:
|
||||
config_file: YAML配置文件路径
|
||||
args: 运行时参数
|
||||
"""
|
||||
self._load_config(config_file)
|
||||
self._merge_runtime_args(args)
|
||||
self._init_components()
|
||||
self._init_data_structures()
|
||||
self.init_ros()
|
||||
self.init_features()
|
||||
self.warmup()
|
||||
|
||||
def _load_config(self, config_file: str) -> None:
|
||||
"""加载YAML配置文件"""
|
||||
with open(config_file, 'r') as f:
|
||||
self.config = yaml.safe_load(f)
|
||||
|
||||
def _merge_runtime_args(self, args: Optional[argparse.Namespace]) -> None:
|
||||
"""合并运行时参数到配置"""
|
||||
if args is None:
|
||||
return
|
||||
|
||||
runtime_params = {
|
||||
'frame_rate': getattr(args, 'fps', None),
|
||||
'max_timesteps': getattr(args, 'max_timesteps', None),
|
||||
'episode_idx': getattr(args, 'episode_idx', None),
|
||||
'use_depth_image': getattr(args, 'use_depth_image', None),
|
||||
'use_robot_base': getattr(args, 'use_base', None),
|
||||
'video': getattr(args, 'video', False),
|
||||
'control_type': getattr(args, 'control_type', False),
|
||||
}
|
||||
|
||||
for key, value in runtime_params.items():
|
||||
if value is not None:
|
||||
self.config[key] = value
|
||||
|
||||
def _init_components(self) -> None:
|
||||
"""初始化核心组件"""
|
||||
self.bridge = CvBridge()
|
||||
self.subscribers = {}
|
||||
self.publishers = {}
|
||||
self._validate_config()
|
||||
|
||||
def _validate_config(self) -> None:
|
||||
"""验证配置完整性"""
|
||||
required_sections = ['cameras', 'arm']
|
||||
for section in required_sections:
|
||||
if section not in self.config:
|
||||
raise ValueError(f"Missing required config section: {section}")
|
||||
|
||||
def _init_data_structures(self) -> None:
|
||||
"""初始化数据结构模板方法"""
|
||||
# 相机数据
|
||||
self.cameras = self.config.get('cameras', {})
|
||||
self.sync_img_queues = {name: deque(maxlen=2000) for name in self.cameras}
|
||||
|
||||
# 深度数据
|
||||
self.use_depth_image = self.config.get('use_depth_image', False)
|
||||
if self.use_depth_image:
|
||||
self.sync_depth_queues = {
|
||||
name: deque(maxlen=2000)
|
||||
for name, cam in self.cameras.items()
|
||||
if 'depth_topic_name' in cam
|
||||
}
|
||||
|
||||
# 机械臂数据
|
||||
self.arms = self.config.get('arm', {})
|
||||
if self.config.get('control_type', '') != 'record':
|
||||
# 如果不是录制模式,则仅初始化从机械臂数据队列
|
||||
self.sync_arm_queues = {name: deque(maxlen=2000) for name in self.arms if 'puppet' in name}
|
||||
else:
|
||||
self.sync_arm_queues = {name: deque(maxlen=2000) for name in self.arms}
|
||||
|
||||
# 机器人基座数据
|
||||
self.use_robot_base = self.config.get('use_robot_base', False)
|
||||
if self.use_robot_base:
|
||||
self.sync_base_queue = deque(maxlen=2000)
|
||||
|
||||
def init_ros(self) -> None:
|
||||
"""初始化ROS订阅的模板方法"""
|
||||
rospy.init_node(
|
||||
f"{self.config.get('ros_node_name', 'generic_robot_node')}",
|
||||
anonymous=True
|
||||
)
|
||||
|
||||
self._setup_camera_subscribers()
|
||||
self._setup_arm_subscribers_publishers()
|
||||
self._setup_base_subscriber()
|
||||
self._log_ros_status()
|
||||
|
||||
def init_features(self):
|
||||
"""
|
||||
根据YAML配置自动生成features结构
|
||||
"""
|
||||
self.features = {}
|
||||
|
||||
# 初始化相机特征
|
||||
self._init_camera_features()
|
||||
|
||||
# 初始化机械臂特征
|
||||
self._init_state_features()
|
||||
|
||||
self._init_action_features()
|
||||
|
||||
# 初始化基座特征(如果启用)
|
||||
if self.use_robot_base:
|
||||
self._init_base_features()
|
||||
import pprint
|
||||
pprint.pprint(self.features, indent=4)
|
||||
|
||||
|
||||
def _init_camera_features(self):
|
||||
"""处理所有相机特征"""
|
||||
for cam_name, cam_config in self.cameras.items():
|
||||
# 普通图像
|
||||
self.features[f"observation.images.{cam_name}"] = {
|
||||
"dtype": "video" if self.config.get("video", False) else "image",
|
||||
"shape": cam_config.get("rgb_shape", [480, 640, 3]),
|
||||
"names": ["height", "width", "channel"],
|
||||
# "video_info": {
|
||||
# "video.fps": cam_config.get("fps", 30.0),
|
||||
# "video.codec": cam_config.get("codec", "av1"),
|
||||
# "video.pix_fmt": cam_config.get("pix_fmt", "yuv420p"),
|
||||
# "video.is_depth_map": False,
|
||||
# "has_audio": False
|
||||
# }
|
||||
}
|
||||
|
||||
if self.config.get("use_depth_image", False):
|
||||
self.features[f"observation.images.depth_{cam_name}"] = {
|
||||
"dtype": "uint16",
|
||||
"shape": (cam_config.get("width", 480), cam_config.get("height", 640), 1),
|
||||
"names": ["height", "width"],
|
||||
}
|
||||
|
||||
|
||||
def _init_state_features(self):
|
||||
state = self.config.get('state', {})
|
||||
# 状态特征
|
||||
self.features["observation.state"] = {
|
||||
"dtype": "float32",
|
||||
"shape": (len(state.get('motors', "")),),
|
||||
"names": {"motors": state.get('motors', "")}
|
||||
}
|
||||
|
||||
if self.config.get('velocity'):
|
||||
velocity = self.config.get('velocity', "")
|
||||
self.features["observation.velocity"] = {
|
||||
"dtype": "float32",
|
||||
"shape": (len(velocity.get('motors', "")),),
|
||||
"names": {"motors": velocity.get('motors', "")}
|
||||
}
|
||||
|
||||
if self.config.get('effort'):
|
||||
effort = self.config.get('effort', "")
|
||||
self.features["observation.effort"] = {
|
||||
"dtype": "float32",
|
||||
"shape": (len(effort.get('motors', "")),),
|
||||
"names": {"motors": effort.get('motors', "")}
|
||||
}
|
||||
|
||||
|
||||
|
||||
def _init_action_features(self):
|
||||
action = self.config.get('action', {})
|
||||
# 状态特征
|
||||
self.features["action"] = {
|
||||
"dtype": "float32",
|
||||
"shape": (len(action.get('motors', "")),),
|
||||
"names": {"motors": action.get('motors', "")}
|
||||
}
|
||||
|
||||
def _init_base_features(self):
|
||||
"""处理基座特征"""
|
||||
self.features["observation.base_vel"] = {
|
||||
"dtype": "float32",
|
||||
"shape": (2,),
|
||||
"names": ["linear_x", "angular_z"]
|
||||
}
|
||||
|
||||
|
||||
def _setup_camera_subscribers(self) -> None:
|
||||
"""设置相机订阅者"""
|
||||
for cam_name, cam_config in self.cameras.items():
|
||||
if 'img_topic_name' in cam_config:
|
||||
self.subscribers[f"camera_{cam_name}"] = rospy.Subscriber(
|
||||
cam_config['img_topic_name'],
|
||||
Image,
|
||||
self._make_camera_callback(cam_name, is_depth=False),
|
||||
queue_size=1000,
|
||||
tcp_nodelay=True
|
||||
)
|
||||
|
||||
if self.use_depth_image and 'depth_topic_name' in cam_config:
|
||||
self.subscribers[f"depth_{cam_name}"] = rospy.Subscriber(
|
||||
cam_config['depth_topic_name'],
|
||||
Image,
|
||||
self._make_camera_callback(cam_name, is_depth=True),
|
||||
queue_size=1000,
|
||||
tcp_nodelay=True
|
||||
)
|
||||
|
||||
def _setup_arm_subscribers_publishers(self) -> None:
|
||||
"""设置机械臂订阅者"""
|
||||
# 当为record模式时,主从机械臂都需要订阅
|
||||
# 否则只订阅从机械臂,但向主机械臂发布
|
||||
if self.config.get('control_type', '') == 'record':
|
||||
for arm_name, arm_config in self.arms.items():
|
||||
if 'topic_name' in arm_config:
|
||||
self.subscribers[f"arm_{arm_name}"] = rospy.Subscriber(
|
||||
arm_config['topic_name'],
|
||||
JointState,
|
||||
self._make_arm_callback(arm_name),
|
||||
queue_size=1000,
|
||||
tcp_nodelay=True
|
||||
)
|
||||
else:
|
||||
for arm_name, arm_config in self.arms.items():
|
||||
if 'puppet' in arm_name:
|
||||
self.subscribers[f"arm_{arm_name}"] = rospy.Subscriber(
|
||||
arm_config['topic_name'],
|
||||
JointState,
|
||||
self._make_arm_callback(arm_name),
|
||||
queue_size=1000,
|
||||
tcp_nodelay=True
|
||||
)
|
||||
if 'master' in arm_name:
|
||||
self.publishers[f"arm_{arm_name}"] = rospy.Publisher(
|
||||
arm_config['topic_name'],
|
||||
JointState,
|
||||
queue_size=10
|
||||
)
|
||||
|
||||
def _setup_base_subscriber(self) -> None:
|
||||
"""设置基座订阅者"""
|
||||
if self.use_robot_base and 'robot_base' in self.config:
|
||||
self.subscribers['base'] = rospy.Subscriber(
|
||||
self.config['robot_base']['topic_name'],
|
||||
Odometry,
|
||||
self.robot_base_callback,
|
||||
queue_size=1000,
|
||||
tcp_nodelay=True
|
||||
)
|
||||
|
||||
def _log_ros_status(self) -> None:
|
||||
"""记录ROS状态"""
|
||||
rospy.loginfo("\n=== ROS订阅状态 ===")
|
||||
rospy.loginfo(f"已初始化节点: {rospy.get_name()}")
|
||||
rospy.loginfo("活跃的订阅者:")
|
||||
for topic, sub in self.subscribers.items():
|
||||
rospy.loginfo(f" - {topic}: {'活跃' if sub.impl else '未连接'}")
|
||||
rospy.loginfo("=================")
|
||||
|
||||
def _make_camera_callback(self, cam_name: str, is_depth: bool = False):
|
||||
"""生成相机回调函数工厂方法"""
|
||||
def callback(msg):
|
||||
try:
|
||||
target_queue = (
|
||||
self.sync_depth_queues[cam_name]
|
||||
if is_depth
|
||||
else self.sync_img_queues[cam_name]
|
||||
)
|
||||
if len(target_queue) >= 2000:
|
||||
target_queue.popleft()
|
||||
target_queue.append(msg)
|
||||
except Exception as e:
|
||||
rospy.logerr(f"Camera {cam_name} callback error: {str(e)}")
|
||||
return callback
|
||||
|
||||
def _make_arm_callback(self, arm_name: str):
|
||||
"""生成机械臂回调函数工厂方法"""
|
||||
def callback(msg):
|
||||
try:
|
||||
if len(self.sync_arm_queues[arm_name]) >= 2000:
|
||||
self.sync_arm_queues[arm_name].popleft()
|
||||
self.sync_arm_queues[arm_name].append(msg)
|
||||
except Exception as e:
|
||||
rospy.logerr(f"Arm {arm_name} callback error: {str(e)}")
|
||||
return callback
|
||||
|
||||
def robot_base_callback(self, msg):
|
||||
"""基座回调默认实现"""
|
||||
if len(self.sync_base_queue) >= 2000:
|
||||
self.sync_base_queue.popleft()
|
||||
self.sync_base_queue.append(msg)
|
||||
|
||||
def warmup(self, timeout: float = 10.0) -> bool:
|
||||
"""Wait until all data queues have at least 20 messages.
|
||||
|
||||
Args:
|
||||
timeout: Maximum time to wait in seconds before giving up
|
||||
|
||||
Returns:
|
||||
bool: True if warmup succeeded, False if timed out
|
||||
"""
|
||||
start_time = rospy.Time.now().to_sec()
|
||||
rate = rospy.Rate(10) # Check at 10Hz
|
||||
|
||||
rospy.loginfo("Starting warmup process...")
|
||||
|
||||
while not rospy.is_shutdown():
|
||||
# Check if timeout has been reached
|
||||
current_time = rospy.Time.now().to_sec()
|
||||
if current_time - start_time > timeout:
|
||||
rospy.logwarn("Warmup timed out before all queues were filled")
|
||||
return False
|
||||
|
||||
# Check all required queues
|
||||
all_ready = True
|
||||
|
||||
# Check camera image queues
|
||||
for cam_name in self.cameras:
|
||||
if len(self.sync_img_queues[cam_name]) < 50:
|
||||
rospy.loginfo(f"Waiting for camera {cam_name} (current: {len(self.sync_img_queues[cam_name])}/50)")
|
||||
all_ready = False
|
||||
break
|
||||
|
||||
# Check depth queues if enabled
|
||||
if self.use_depth_image:
|
||||
for cam_name in self.sync_depth_queues:
|
||||
if len(self.sync_depth_queues[cam_name]) < 50:
|
||||
rospy.loginfo(f"Waiting for depth camera {cam_name} (current: {len(self.sync_depth_queues[cam_name])}/50)")
|
||||
all_ready = False
|
||||
break
|
||||
|
||||
# # Check arm queues
|
||||
# for arm_name in self.arms:
|
||||
# if len(self.sync_arm_queues[arm_name]) < 20:
|
||||
# rospy.loginfo(f"Waiting for arm {arm_name} (current: {len(self.sync_arm_queues[arm_name])}/20)")
|
||||
# all_ready = False
|
||||
# break
|
||||
|
||||
# Check base queue if enabled
|
||||
if self.use_robot_base:
|
||||
if len(self.sync_base_queue) < 20:
|
||||
rospy.loginfo(f"Waiting for base (current: {len(self.sync_base_queue)}/20)")
|
||||
all_ready = False
|
||||
|
||||
# If all queues are ready, return success
|
||||
if all_ready:
|
||||
rospy.loginfo("Warmup completed successfully")
|
||||
return True
|
||||
|
||||
rate.sleep()
|
||||
|
||||
return False
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def get_frame(self) -> Optional[Dict[str, Any]]:
|
||||
"""获取同步帧数据的模板方法"""
|
||||
raise NotImplementedError("Subclasses must implement get_frame()")
|
||||
|
||||
def process(self) -> tuple:
|
||||
"""主处理循环的模板方法"""
|
||||
raise NotImplementedError("Subclasses must implement process()")
|
||||
Reference in New Issue
Block a user