422 lines
16 KiB
Python
422 lines
16 KiB
Python
import yaml
|
|
from collections import deque
|
|
import rospy
|
|
from cv_bridge import CvBridge
|
|
from typing import Dict, Any, Optional, List
|
|
from sensor_msgs.msg import Image, JointState
|
|
from nav_msgs.msg import Odometry
|
|
import argparse
|
|
|
|
|
|
class RobotConfig:
    """Configuration management for robot components.

    Holds the mapping parsed from a YAML file in ``self.config`` and lets
    command-line arguments override individual keys.
    """

    # Top-level sections that must exist in the YAML file.
    _REQUIRED_SECTIONS = ('cameras', 'arm')

    # Configuration key -> name of the argparse attribute that overrides it.
    _RUNTIME_ARGS = {
        'frame_rate': 'fps',
        'max_timesteps': 'max_timesteps',
        'episode_idx': 'episode_idx',
        'use_depth_image': 'use_depth_image',
        'use_robot_base': 'use_base',
        'video': 'video',
        'control_type': 'control_type',
    }

    def __init__(self, config_file: str):
        """
        Initialize robot configuration from YAML file

        Args:
            config_file: Path to YAML configuration file

        Raises:
            ValueError: If the file does not contain a mapping or a required
                section is missing.
        """
        self.config = self._load_yaml(config_file)
        self._validate_config()

    def _load_yaml(self, config_file: str) -> Dict[str, Any]:
        """Load configuration from YAML file.

        Args:
            config_file: Path to YAML configuration file

        Returns:
            The parsed top-level mapping.

        Raises:
            ValueError: If the document is empty or its top level is not a
                mapping (``yaml.safe_load`` returns ``None`` for an empty file).
        """
        with open(config_file, 'r', encoding='utf-8') as f:
            loaded = yaml.safe_load(f)

        # Fail here with a clear message instead of letting the later
        # ``section not in self.config`` check raise an opaque TypeError.
        if not isinstance(loaded, dict):
            raise ValueError(
                f"Config file {config_file} must contain a YAML mapping, "
                f"got {type(loaded).__name__}"
            )
        return loaded

    def _validate_config(self) -> None:
        """Validate configuration completeness.

        Raises:
            ValueError: If one of the required sections is missing.
        """
        for section in self._REQUIRED_SECTIONS:
            if section not in self.config:
                raise ValueError(f"Missing required config section: {section}")

    def merge_runtime_args(self, args: Optional[argparse.Namespace]) -> None:
        """
        Merge runtime arguments into configuration

        Only attributes that are present on ``args`` and not ``None`` override
        the loaded configuration; everything else keeps its YAML value.

        Args:
            args: Runtime arguments from command line
        """
        if args is None:
            return

        for key, attr_name in self._RUNTIME_ARGS.items():
            # A missing attribute is treated like ``None`` for every key, so
            # an absent 'video' / 'control_type' no longer clobbers the YAML
            # value with ``False``.
            value = getattr(args, attr_name, None)
            if value is not None:
                self.config[key] = value

    def get(self, key: str, default=None) -> Any:
        """Get configuration value with default fallback"""
        return self.config.get(key, default)
|
|
|
|
|
|
class RosAdapter:
    """Adapter for ROS communication.

    Keeps the shared ``CvBridge`` plus two registries (``subscribers`` and
    ``publishers``) that the other components fill in.
    """

    def __init__(self, config: RobotConfig):
        """
        Initialize ROS adapter

        Args:
            config: Robot configuration
        """
        self.config = config
        self.bridge = CvBridge()
        # Registries are populated by the callers of create_subscriber /
        # create_publisher, not by this class.
        self.subscribers, self.publishers = {}, {}

    def init_ros_node(self, node_name: str = None) -> None:
        """Initialize ROS node.

        Args:
            node_name: Node name; when ``None`` the ``ros_node_name`` config
                value is used, falling back to ``generic_robot_node``.
        """
        resolved_name = (
            node_name
            if node_name is not None
            else self.config.get('ros_node_name', 'generic_robot_node')
        )
        rospy.init_node(resolved_name, anonymous=True)

    def create_subscriber(self, topic: str, msg_type, callback, queue_size: int = 1000, tcp_nodelay: bool = True):
        """Create a ROS subscriber.

        Args:
            topic: Topic name to subscribe to
            msg_type: ROS message class of the topic
            callback: Function invoked with each received message
            queue_size: Passed through to ``rospy.Subscriber``
            tcp_nodelay: Passed through to ``rospy.Subscriber``

        Returns:
            The created ``rospy.Subscriber``.
        """
        return rospy.Subscriber(
            topic, msg_type, callback, queue_size=queue_size, tcp_nodelay=tcp_nodelay
        )

    def create_publisher(self, topic: str, msg_type, queue_size: int = 10):
        """Create a ROS publisher.

        Args:
            topic: Topic name to publish on
            msg_type: ROS message class of the topic
            queue_size: Passed through to ``rospy.Publisher``

        Returns:
            The created ``rospy.Publisher``.
        """
        return rospy.Publisher(topic, msg_type, queue_size=queue_size)

    def log_status(self) -> None:
        """Log the node name and the state of every registered subscriber."""
        rospy.loginfo("\n=== ROS订阅状态 ===")
        rospy.loginfo(f"已初始化节点: {rospy.get_name()}")
        rospy.loginfo("活跃的订阅者:")
        for topic in self.subscribers:
            sub = self.subscribers[topic]
            # A falsy ``impl`` is reported as "not connected".
            state = '活跃' if sub.impl else '未连接'
            rospy.loginfo(f" - {topic}: {state}")
        rospy.loginfo("=================")
|
|
|
|
|
|
class RobotSensors:
    """Management of robot sensors (cameras, depth sensors).

    Buffers incoming camera, depth and base messages in bounded queues and
    describes the corresponding dataset features.
    """

    # Maximum number of messages kept per queue. The deques are created with
    # this ``maxlen``, so appending to a full queue drops the oldest entry.
    _QUEUE_MAXLEN = 2000

    def __init__(self, config: "RobotConfig", ros_adapter: "RosAdapter"):
        """
        Initialize robot sensors

        Args:
            config: Robot configuration
            ros_adapter: ROS communication adapter
        """
        self.config = config
        self.ros_adapter = ros_adapter
        self.bridge = ros_adapter.bridge

        # Camera data
        self.cameras = config.get('cameras', {})
        self.sync_img_queues = {
            name: deque(maxlen=self._QUEUE_MAXLEN) for name in self.cameras
        }

        # Depth data: only cameras that declare a depth topic get a queue.
        self.use_depth_image = config.get('use_depth_image', False)
        if self.use_depth_image:
            self.sync_depth_queues = {
                name: deque(maxlen=self._QUEUE_MAXLEN)
                for name, cam in self.cameras.items()
                if 'depth_topic_name' in cam
            }

        # Robot base data
        self.use_robot_base = config.get('use_robot_base', False)
        if self.use_robot_base:
            self.sync_base_queue = deque(maxlen=self._QUEUE_MAXLEN)

    def setup_subscribers(self) -> None:
        """Set up ROS subscribers for sensors"""
        self._setup_camera_subscribers()
        if self.use_robot_base:
            self._setup_base_subscriber()

    def _setup_camera_subscribers(self) -> None:
        """Subscribe to each camera's image topic and, if enabled, depth topic."""
        for cam_name, cam_config in self.cameras.items():
            if 'img_topic_name' in cam_config:
                self.ros_adapter.subscribers[f"camera_{cam_name}"] = self.ros_adapter.create_subscriber(
                    cam_config['img_topic_name'],
                    Image,
                    self._make_camera_callback(cam_name, is_depth=False)
                )

            if self.use_depth_image and 'depth_topic_name' in cam_config:
                self.ros_adapter.subscribers[f"depth_{cam_name}"] = self.ros_adapter.create_subscriber(
                    cam_config['depth_topic_name'],
                    Image,
                    self._make_camera_callback(cam_name, is_depth=True)
                )

    def _setup_base_subscriber(self) -> None:
        """Subscribe to the base odometry topic from the ``robot_base`` section."""
        if 'robot_base' not in self.config.config:
            # Without this subscription the base queue never fills, so make
            # the misconfiguration visible instead of skipping silently.
            rospy.logwarn(
                "use_robot_base is enabled but the config has no 'robot_base' "
                "section; base topic not subscribed"
            )
            return

        self.ros_adapter.subscribers['base'] = self.ros_adapter.create_subscriber(
            self.config.get('robot_base')['topic_name'],
            Odometry,
            self.robot_base_callback
        )

    def _make_camera_callback(self, cam_name: str, is_depth: bool = False):
        """Build the subscriber callback for one camera stream.

        Args:
            cam_name: Camera key in ``self.cameras``
            is_depth: Store into the depth queue instead of the image queue

        Returns:
            A one-argument callable that appends the message to the queue.
        """
        def callback(msg):
            try:
                target_queue = (
                    self.sync_depth_queues[cam_name]
                    if is_depth
                    else self.sync_img_queues[cam_name]
                )
                # The deque is bounded, so the oldest message is discarded
                # automatically when the queue is full.
                target_queue.append(msg)
            except Exception as e:
                rospy.logerr(f"Camera {cam_name} callback error: {str(e)}")
        return callback

    def robot_base_callback(self, msg):
        """Base callback default implementation: buffer the message."""
        # Bounded deque: a full queue drops its oldest entry on append.
        self.sync_base_queue.append(msg)

    def init_features(self) -> Dict[str, Any]:
        """Initialize sensor features.

        Returns:
            Mapping of feature name to its ``dtype`` / ``shape`` / ``names``.
        """
        features = {}

        # Initialize camera features
        self._init_camera_features(features)

        # Initialize base features (if enabled)
        if self.use_robot_base:
            self._init_base_features(features)

        return features

    def _init_camera_features(self, features: Dict[str, Any]) -> None:
        """Add the RGB (and optional depth) feature of every camera.

        Args:
            features: Feature mapping that is updated in place
        """
        image_dtype = "video" if self.config.get("video", False) else "image"
        use_depth = self.config.get("use_depth_image", False)

        for cam_name, cam_config in self.cameras.items():
            # Regular images
            features[f"observation.images.{cam_name}"] = {
                "dtype": image_dtype,
                "shape": cam_config.get("rgb_shape", [480, 640, 3]),
                "names": ["height", "width", "channel"],
            }

            if use_depth:
                # Height first, then width, matching "names" and the RGB
                # default above. The defaults still give (480, 640, 1).
                # NOTE(review): "names" lists two axes for a three-axis shape,
                # and cameras without 'depth_topic_name' also get a depth
                # feature although no queue exists for them; confirm intent.
                features[f"observation.images.depth_{cam_name}"] = {
                    "dtype": "uint16",
                    "shape": (cam_config.get("height", 480), cam_config.get("width", 640), 1),
                    "names": ["height", "width"],
                }

    def _init_base_features(self, features: Dict[str, Any]) -> None:
        """Add the base velocity feature.

        Args:
            features: Feature mapping that is updated in place
        """
        features["observation.base_vel"] = {
            "dtype": "float32",
            "shape": (2,),
            "names": ["linear_x", "angular_z"]
        }
|
|
|
|
|
|
class RobotActuators:
    """Management of robot actuators (arms, base).

    Buffers incoming arm joint states in bounded queues, creates the arm
    subscribers/publishers and describes the state/action dataset features.
    """

    # Maximum number of messages kept per arm queue. The deques are created
    # with this ``maxlen``, so appending to a full queue drops the oldest entry.
    _QUEUE_MAXLEN = 2000

    def __init__(self, config: "RobotConfig", ros_adapter: "RosAdapter"):
        """
        Initialize robot actuators

        Args:
            config: Robot configuration
            ros_adapter: ROS communication adapter
        """
        self.config = config
        self.ros_adapter = ros_adapter

        # Arm data
        self.arms = config.get('arm', {})
        record_mode = config.get('control_type', '') == 'record'
        # In record mode every arm gets a queue; otherwise only the arms
        # whose name contains 'puppet'.
        self.sync_arm_queues = {
            name: deque(maxlen=self._QUEUE_MAXLEN)
            for name in self.arms
            if record_mode or 'puppet' in name
        }

    def setup_subscribers_publishers(self) -> None:
        """Set up ROS subscribers and publishers for actuators"""
        self._setup_arm_subscribers_publishers()

    def _setup_arm_subscribers_publishers(self) -> None:
        """Set up arm subscribers and publishers.

        When in record mode, subscribe to both master and puppet arms.
        Otherwise only subscribe to puppet arms, but publish to master arms.
        Arms without a ``topic_name`` are skipped with a warning in both modes.
        """
        record_mode = self.config.get('control_type', '') == 'record'

        for arm_name, arm_config in self.arms.items():
            if 'topic_name' not in arm_config:
                # Same handling in both modes: skip instead of raising
                # KeyError outside record mode.
                rospy.logwarn(f"Arm {arm_name} has no 'topic_name' configured; skipping")
                continue

            if record_mode or 'puppet' in arm_name:
                self.ros_adapter.subscribers[f"arm_{arm_name}"] = self.ros_adapter.create_subscriber(
                    arm_config['topic_name'],
                    JointState,
                    self._make_arm_callback(arm_name)
                )

            if not record_mode and 'master' in arm_name:
                self.ros_adapter.publishers[f"arm_{arm_name}"] = self.ros_adapter.create_publisher(
                    arm_config['topic_name'],
                    JointState
                )

    def _make_arm_callback(self, arm_name: str):
        """Build the subscriber callback for one arm.

        Args:
            arm_name: Arm key in ``self.sync_arm_queues``

        Returns:
            A one-argument callable that appends the message to the arm queue.
        """
        def callback(msg):
            try:
                # Bounded deque: a full queue drops its oldest entry on append.
                self.sync_arm_queues[arm_name].append(msg)
            except Exception as e:
                rospy.logerr(f"Arm {arm_name} callback error: {str(e)}")
        return callback

    def init_features(self) -> Dict[str, Any]:
        """Initialize actuator features.

        Returns:
            Mapping with the ``observation.state`` and ``action`` features.
        """
        features = {}

        # Initialize arm features
        self._init_state_features(features)
        self._init_action_features(features)

        return features

    def _init_state_features(self, features: Dict[str, Any]) -> None:
        """Initialize state features.

        Args:
            features: Feature mapping that is updated in place
        """
        # ``or {}`` also covers a ``state:`` key that parsed to None.
        state = self.config.get('state', {}) or {}
        motors = state.get('motors', "")
        # State features
        features["observation.state"] = {
            "dtype": "float32",
            "shape": (len(motors),),
            "names": {"motors": motors}
        }

        # if self.config.get('velocity'):
        #     velocity = self.config.get('velocity', "")
        #     features["observation.velocity"] = {
        #         "dtype": "float32",
        #         "shape": (len(velocity.get('motors', "")),),
        #         "names": {"motors": velocity.get('motors', "")}
        #     }

        # if self.config.get('effort'):
        #     effort = self.config.get('effort', "")
        #     features["observation.effort"] = {
        #         "dtype": "float32",
        #         "shape": (len(effort.get('motors', "")),),
        #         "names": {"motors": effort.get('motors', "")}
        #     }

    def _init_action_features(self, features: Dict[str, Any]) -> None:
        """Initialize action features.

        Args:
            features: Feature mapping that is updated in place
        """
        # ``or {}`` also covers an ``action:`` key that parsed to None.
        action = self.config.get('action', {}) or {}
        motors = action.get('motors', "")
        features["action"] = {
            "dtype": "float32",
            "shape": (len(motors),),
            "names": {"motors": motors}
        }
|
|
|
|
|
|
class RobotDataManager:
|
|
"""Management of robot data collection and synchronization"""
|
|
|
|
def __init__(self, config: RobotConfig, sensors: RobotSensors, actuators: RobotActuators):
    """
    Store the collaborators used for data collection

    Args:
        config: Robot configuration
        sensors: Robot sensors component
        actuators: Robot actuators component
    """
    # The three references are only stored here; nothing is started.
    self.actuators = actuators
    self.sensors = sensors
    self.config = config
|
|
|
|
def warmup(self, timeout: float = 30.0, min_image_msgs: int = 200, min_base_msgs: int = 20) -> bool:
    """
    Wait until the sensor queues have sufficient messages

    Polls at 10 Hz until every camera queue (and, when enabled, every depth
    queue and the base queue) holds the required number of messages.

    NOTE(review): arm queues held by ``self.actuators`` are not checked
    here; confirm whether that is intended.

    Args:
        timeout: Maximum time to wait in seconds
        min_image_msgs: Messages required in each camera and depth queue
        min_base_msgs: Messages required in the base queue

    Returns:
        bool: True if warmup succeeded, False if it timed out or ROS shut
        down before the queues were filled
    """
    start_time = rospy.Time.now().to_sec()
    rate = rospy.Rate(10)  # Check at 10Hz

    rospy.loginfo("Starting warmup process...")
    # Logged once; the camera set does not change while waiting.
    rospy.loginfo(f"Nums of camera is {len(self.sensors.cameras)}")

    while not rospy.is_shutdown():
        # Check if timeout has been reached
        if rospy.Time.now().to_sec() - start_time > timeout:
            rospy.logwarn("Warmup timed out before all queues were filled")
            return False

        all_ready = True

        # Check camera image queues; report only the first one still short.
        for cam_name in self.sensors.cameras:
            queued = len(self.sensors.sync_img_queues[cam_name])
            if queued < min_image_msgs:
                rospy.loginfo(f"Waiting for camera {cam_name} (current: {queued}/{min_image_msgs})")
                all_ready = False
                break

        # Check depth queues if enabled
        if self.sensors.use_depth_image:
            for cam_name, depth_queue in self.sensors.sync_depth_queues.items():
                queued = len(depth_queue)
                if queued < min_image_msgs:
                    rospy.loginfo(f"Waiting for depth camera {cam_name} (current: {queued}/{min_image_msgs})")
                    all_ready = False
                    break

        # Check base queue if enabled
        if self.sensors.use_robot_base:
            queued = len(self.sensors.sync_base_queue)
            if queued < min_base_msgs:
                rospy.loginfo(f"Waiting for base (current: {queued}/{min_base_msgs})")
                all_ready = False

        # If all queues are ready, return success
        if all_ready:
            rospy.loginfo("Warmup completed successfully")
            return True

        rate.sleep()

    # Reached only when ROS shuts down while still waiting.
    return False