modify code structure
This commit is contained in:
Binary file not shown.
BIN
collect_data/__pycache__/robot_components.cpython-310.pyc
Normal file
BIN
collect_data/__pycache__/robot_components.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
collect_data/__pycache__/rosrobot_factory.cpython-310.pyc
Normal file
BIN
collect_data/__pycache__/rosrobot_factory.cpython-310.pyc
Normal file
Binary file not shown.
@@ -19,6 +19,12 @@ cameras:
|
||||
rgb_shape: [480, 640, 3]
|
||||
width: 480
|
||||
height: 640
|
||||
cam_high:
|
||||
img_topic_name: /camera/color/image_raw
|
||||
depth_topic_name: /camera/depth/image_rect_raw
|
||||
rgb_shape: [480, 640, 3]
|
||||
width: 480
|
||||
height: 640
|
||||
|
||||
arm:
|
||||
master_left:
|
||||
|
||||
@@ -1,456 +0,0 @@
|
||||
import yaml
|
||||
import cv2
|
||||
import numpy as np
|
||||
import collections
|
||||
import dm_env
|
||||
import argparse
|
||||
from typing import Dict, List, Any, Optional
|
||||
from collections import deque
|
||||
import rospy
|
||||
from cv_bridge import CvBridge
|
||||
from std_msgs.msg import Header
|
||||
from sensor_msgs.msg import Image, JointState
|
||||
from nav_msgs.msg import Odometry
|
||||
from rosrobot import Robot
|
||||
import torch
|
||||
import time
|
||||
|
||||
|
||||
class AgilexRobot(Robot):
|
||||
def get_frame(self) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取同步帧数据
|
||||
返回: 包含同步数据的字典,或None如果同步失败
|
||||
"""
|
||||
# 检查基本数据可用性
|
||||
# print(self.sync_img_queues.values())
|
||||
if any(len(q) == 0 for q in self.sync_img_queues.values()):
|
||||
print("camera has not get image data")
|
||||
return None
|
||||
|
||||
if self.use_depth_image and any(len(q) == 0 for q in self.sync_depth_queues.values()):
|
||||
return None
|
||||
|
||||
# print(self.sync_arm_queues.values())
|
||||
# if any(len(q) == 0 for q in self.sync_arm_queues.values()):
|
||||
# print("2")
|
||||
# if len(self.sync_arm_queues['master_left']) == 0:
|
||||
# print("can not get data from master topic")
|
||||
# if len(self.sync_arm_queues['puppet_left']) == 0 or len(self.sync_arm_queues['puppet_right']) == 0:
|
||||
# print("can not get data from puppet topic")
|
||||
# return None
|
||||
|
||||
if len(self.sync_arm_queues['puppet_left']) == 0 or len(self.sync_arm_queues['puppet_right']) == 0:
|
||||
print("can not get data from puppet topic")
|
||||
return None
|
||||
|
||||
# 计算最小时间戳
|
||||
timestamps = [
|
||||
q[-1].header.stamp.to_sec()
|
||||
for q in list(self.sync_img_queues.values()) +
|
||||
list(self.sync_arm_queues.values())
|
||||
]
|
||||
|
||||
if self.use_depth_image:
|
||||
timestamps.extend(q[-1].header.stamp.to_sec() for q in self.sync_depth_queues.values())
|
||||
|
||||
if self.use_robot_base and len(self.sync_base_queue) > 0:
|
||||
timestamps.append(self.sync_base_queue[-1].header.stamp.to_sec())
|
||||
|
||||
min_time = min(timestamps)
|
||||
|
||||
# 检查数据同步性
|
||||
for queue in list(self.sync_img_queues.values()) + list(self.sync_arm_queues.values()):
|
||||
if queue[-1].header.stamp.to_sec() < min_time:
|
||||
return None
|
||||
|
||||
if self.use_depth_image:
|
||||
for queue in self.sync_depth_queues.values():
|
||||
if queue[-1].header.stamp.to_sec() < min_time:
|
||||
return None
|
||||
|
||||
if self.use_robot_base and len(self.sync_base_queue) > 0:
|
||||
if self.sync_base_queue[-1].header.stamp.to_sec() < min_time:
|
||||
return None
|
||||
|
||||
# 提取同步数据
|
||||
frame_data = {
|
||||
'images': {},
|
||||
'arms': {},
|
||||
'timestamp': min_time
|
||||
}
|
||||
|
||||
# 图像数据
|
||||
for cam_name, queue in self.sync_img_queues.items():
|
||||
while queue[0].header.stamp.to_sec() < min_time:
|
||||
queue.popleft()
|
||||
frame_data['images'][cam_name] = self.bridge.imgmsg_to_cv2(queue.popleft(), 'passthrough')
|
||||
|
||||
# 深度数据
|
||||
if self.use_depth_image:
|
||||
frame_data['depths'] = {}
|
||||
for cam_name, queue in self.sync_depth_queues.items():
|
||||
while queue[0].header.stamp.to_sec() < min_time:
|
||||
queue.popleft()
|
||||
depth_img = self.bridge.imgmsg_to_cv2(queue.popleft(), 'passthrough')
|
||||
# 保持原有的边界填充
|
||||
frame_data['depths'][cam_name] = cv2.copyMakeBorder(
|
||||
depth_img, 40, 40, 0, 0, cv2.BORDER_CONSTANT, value=0
|
||||
)
|
||||
|
||||
# 机械臂数据
|
||||
for arm_name, queue in self.sync_arm_queues.items():
|
||||
while queue[0].header.stamp.to_sec() < min_time:
|
||||
queue.popleft()
|
||||
frame_data['arms'][arm_name] = queue.popleft()
|
||||
|
||||
# 基座数据
|
||||
if self.use_robot_base and len(self.sync_base_queue) > 0:
|
||||
while self.sync_base_queue[0].header.stamp.to_sec() < min_time:
|
||||
self.sync_base_queue.popleft()
|
||||
frame_data['base'] = self.sync_base_queue.popleft()
|
||||
|
||||
return frame_data
|
||||
|
||||
|
||||
def teleop_step(self) -> Optional[tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]]:
|
||||
"""
|
||||
获取同步帧数据,输出格式与 teleop_step 一致
|
||||
返回: (obs_dict, action_dict) 或 None 如果同步失败
|
||||
"""
|
||||
# 检查基本数据可用性
|
||||
if any(len(q) == 0 for q in self.sync_img_queues.values()):
|
||||
return None, None
|
||||
|
||||
if self.use_depth_image and any(len(q) == 0 for q in self.sync_depth_queues.values()):
|
||||
return None, None
|
||||
|
||||
if any(len(q) == 0 for q in self.sync_arm_queues.values()):
|
||||
return None, None
|
||||
|
||||
# 计算最小时间戳
|
||||
timestamps = [
|
||||
q[-1].header.stamp.to_sec()
|
||||
for q in list(self.sync_img_queues.values()) +
|
||||
list(self.sync_arm_queues.values())
|
||||
]
|
||||
|
||||
if self.use_depth_image:
|
||||
timestamps.extend(q[-1].header.stamp.to_sec() for q in self.sync_depth_queues.values())
|
||||
|
||||
if self.use_robot_base and len(self.sync_base_queue) > 0:
|
||||
timestamps.append(self.sync_base_queue[-1].header.stamp.to_sec())
|
||||
|
||||
min_time = min(timestamps)
|
||||
|
||||
# 检查数据同步性
|
||||
for queue in list(self.sync_img_queues.values()) + list(self.sync_arm_queues.values()):
|
||||
if queue[-1].header.stamp.to_sec() < min_time:
|
||||
return None, None
|
||||
|
||||
if self.use_depth_image:
|
||||
for queue in self.sync_depth_queues.values():
|
||||
if queue[-1].header.stamp.to_sec() < min_time:
|
||||
return None, None
|
||||
|
||||
if self.use_robot_base and len(self.sync_base_queue) > 0:
|
||||
if self.sync_base_queue[-1].header.stamp.to_sec() < min_time:
|
||||
return None, None
|
||||
|
||||
# 初始化输出字典
|
||||
obs_dict = {}
|
||||
action_dict = {}
|
||||
|
||||
# 处理图像数据
|
||||
for cam_name, queue in self.sync_img_queues.items():
|
||||
while queue[0].header.stamp.to_sec() < min_time:
|
||||
queue.popleft()
|
||||
img = self.bridge.imgmsg_to_cv2(queue.popleft(), 'passthrough')
|
||||
obs_dict[f"observation.images.{cam_name}"] = torch.from_numpy(img)
|
||||
|
||||
# 处理深度数据
|
||||
if self.use_depth_image:
|
||||
for cam_name, queue in self.sync_depth_queues.items():
|
||||
while queue[0].header.stamp.to_sec() < min_time:
|
||||
queue.popleft()
|
||||
depth_img = self.bridge.imgmsg_to_cv2(queue.popleft(), 'passthrough')
|
||||
depth_img = cv2.copyMakeBorder(
|
||||
depth_img, 40, 40, 0, 0, cv2.BORDER_CONSTANT, value=0
|
||||
)
|
||||
obs_dict[f"observation.images.depth_{cam_name}"] = torch.from_numpy(depth_img).unsqueeze(-1)
|
||||
|
||||
# 处理机械臂观测数据
|
||||
arm_states = []
|
||||
arm_velocity = []
|
||||
arm_effort = []
|
||||
actions = []
|
||||
for arm_name, queue in self.sync_arm_queues.items():
|
||||
while queue[0].header.stamp.to_sec() < min_time:
|
||||
queue.popleft()
|
||||
arm_data = queue.popleft()
|
||||
|
||||
# np.array(arm_data.position),
|
||||
# np.array(arm_data.velocity),
|
||||
# np.array(arm_data.effort)
|
||||
# 如果是从臂(puppet),作为观测
|
||||
if arm_name.startswith('puppet'):
|
||||
arm_states.append(np.array(arm_data.position, dtype=np.float32))
|
||||
arm_velocity.append(np.array(arm_data.velocity, dtype=np.float32))
|
||||
arm_effort.append(np.array(arm_data.effort, dtype=np.float32))
|
||||
|
||||
# 如果是主臂(master),作为动作
|
||||
if arm_name.startswith('master'):
|
||||
# action_dict[f"action.{arm_name}"] = torch.from_numpy(np.array(arm_data.position))
|
||||
actions.append(np.array(arm_data.position, dtype=np.float32))
|
||||
|
||||
if arm_states:
|
||||
obs_dict["observation.state"] = torch.tensor(np.concatenate(arm_states).reshape(-1)) # 先转Python列表
|
||||
|
||||
if arm_velocity:
|
||||
obs_dict["observation.velocity"] = torch.tensor(np.concatenate(arm_velocity).reshape(-1))
|
||||
|
||||
if arm_effort:
|
||||
obs_dict["observation.effort"] = torch.tensor(np.concatenate(arm_effort).reshape(-1))
|
||||
|
||||
if actions:
|
||||
action_dict["action"] = torch.tensor(np.concatenate(actions).reshape(-1))
|
||||
# action_dict["action"] = np.concatenate(actions).squeeze()
|
||||
|
||||
# 处理基座数据
|
||||
if self.use_robot_base and len(self.sync_base_queue) > 0:
|
||||
while self.sync_base_queue[0].header.stamp.to_sec() < min_time:
|
||||
self.sync_base_queue.popleft()
|
||||
base_data = self.sync_base_queue.popleft()
|
||||
obs_dict["observation.base_vel"] = torch.tensor([
|
||||
base_data.twist.twist.linear.x,
|
||||
base_data.twist.twist.angular.z
|
||||
], dtype=torch.float32)
|
||||
|
||||
# 添加时间戳
|
||||
# obs_dict["observation.timestamp"] = torch.tensor(min_time, dtype=torch.float64)
|
||||
|
||||
return obs_dict, action_dict
|
||||
|
||||
|
||||
def capture_observation(self):
|
||||
"""Capture observation data from ROS topics without batch dimension.
|
||||
|
||||
Returns:
|
||||
dict: Observation dictionary containing state and images.
|
||||
|
||||
Raises:
|
||||
RobotDeviceNotConnectedError: If robot is not connected.
|
||||
"""
|
||||
# Initialize observation dictionary
|
||||
obs_dict = {}
|
||||
|
||||
# Get synchronized frame data
|
||||
frame_data = self.get_frame()
|
||||
if frame_data is None:
|
||||
# raise RuntimeError("Failed to capture synchronized observation data")
|
||||
return None
|
||||
|
||||
# Process arm state data (from puppet arms)
|
||||
arm_states = []
|
||||
arm_velocity = []
|
||||
arm_effort = []
|
||||
|
||||
for arm_name, joint_state in frame_data['arms'].items():
|
||||
if arm_name.startswith('puppet'):
|
||||
# Record timing for performance monitoring
|
||||
before_read_t = time.perf_counter()
|
||||
|
||||
# Get position data and convert to tensor
|
||||
pos = torch.from_numpy(np.array(joint_state.position, dtype=np.float32))
|
||||
arm_states.append(pos)
|
||||
|
||||
velocity = torch.from_numpy(np.array(joint_state.velocity, dtype=np.float32))
|
||||
arm_velocity.append(velocity)
|
||||
|
||||
effort = torch.from_numpy(np.array(joint_state.effort, dtype=np.float32))
|
||||
arm_effort.append(effort)
|
||||
|
||||
# Log timing information
|
||||
# self.logs[f"read_arm_{arm_name}_pos_dt_s"] = time.perf_counter() - before_read_t
|
||||
print(f"read_arm_{arm_name}_pos_dt_s is", time.perf_counter() - before_read_t)
|
||||
|
||||
# Combine all arm states into single tensor
|
||||
if arm_states:
|
||||
obs_dict["observation.state"] = torch.cat(arm_states)
|
||||
|
||||
if arm_velocity:
|
||||
obs_dict["observation.velocity"] = torch.cat(arm_velocity)
|
||||
|
||||
if arm_effort:
|
||||
obs_dict["observation.effort"] = torch.cat(arm_effort)
|
||||
|
||||
# Process image data
|
||||
for cam_name, img in frame_data['images'].items():
|
||||
# Record timing for performance monitoring
|
||||
before_camread_t = time.perf_counter()
|
||||
|
||||
# Convert image to tensor
|
||||
img_tensor = torch.from_numpy(img)
|
||||
obs_dict[f"observation.images.{cam_name}"] = img_tensor
|
||||
|
||||
# Log timing information
|
||||
# self.logs[f"read_camera_{cam_name}_dt_s"] = time.perf_counter() - before_camread_t
|
||||
print(f"read_camera_{cam_name}_dt_s is", time.perf_counter() - before_camread_t)
|
||||
|
||||
# Process depth data if enabled
|
||||
if self.use_depth_image and 'depths' in frame_data:
|
||||
for cam_name, depth_img in frame_data['depths'].items():
|
||||
before_depthread_t = time.perf_counter()
|
||||
|
||||
# Convert depth image to tensor and add channel dimension
|
||||
depth_tensor = torch.from_numpy(depth_img).unsqueeze(-1)
|
||||
obs_dict[f"observation.images.depth_{cam_name}"] = depth_tensor
|
||||
|
||||
# self.logs[f"read_depth_{cam_name}_dt_s"] = time.perf_counter() - before_depthread_t
|
||||
print(f"read_depth_{cam_name}_dt_s is", time.perf_counter() - before_depthread_t)
|
||||
|
||||
# Process base velocity if enabled
|
||||
if self.use_robot_base and 'base' in frame_data:
|
||||
base_data = frame_data['base']
|
||||
obs_dict["observation.base_vel"] = torch.tensor([
|
||||
base_data.twist.twist.linear.x,
|
||||
base_data.twist.twist.angular.z
|
||||
], dtype=torch.float32)
|
||||
|
||||
return obs_dict
|
||||
|
||||
def send_action(self, action: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Send joint position commands to the puppet arms via ROS.
|
||||
|
||||
Args:
|
||||
action: Tensor containing concatenated goal positions for all puppet arms
|
||||
Shape should match the action space defined in features["action"]
|
||||
|
||||
Returns:
|
||||
The actual action that was sent (may be clipped if safety checks are implemented)
|
||||
"""
|
||||
# if not hasattr(self, 'puppet_arm_publishers'):
|
||||
# # Initialize publishers on first call
|
||||
# self._init_action_publishers()
|
||||
|
||||
last_velocity = [-0.010990142822265625, -0.010990142822265625, -0.03296661376953125, 0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.03296661376953125]
|
||||
last_effort = [-0.021978378295898438, 0.2417583465576172, 4.320878982543945, 3.6527481079101562, -0.013187408447265625, -0.013187408447265625, 0.0, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.03296661376953125]
|
||||
|
||||
# Convert tensor to numpy array if needed
|
||||
if isinstance(action, torch.Tensor):
|
||||
action = action.detach().cpu().numpy()
|
||||
|
||||
# Split action into individual arm commands based on config
|
||||
from_idx = 0
|
||||
to_idx = 0
|
||||
action_sent = []
|
||||
for arm_name, arm_config in self.arms.items():
|
||||
# 主臂topic是否存在
|
||||
if not "master" in arm_name:
|
||||
continue
|
||||
|
||||
# Get number of joints for this arm
|
||||
num_joints = len(arm_config.get('motors', []))
|
||||
to_idx += num_joints
|
||||
|
||||
# Extract this arm's portion of the action
|
||||
arm_action = action[from_idx:to_idx]
|
||||
arm_velocity = last_velocity[from_idx:to_idx]
|
||||
arm_effort = last_effort[from_idx:to_idx]
|
||||
from_idx = to_idx
|
||||
|
||||
# Apply safety checks if configured
|
||||
if 'max_relative_target' in self.config:
|
||||
# Get current position from the queue
|
||||
if len(self.sync_arm_queues[arm_name]) > 0:
|
||||
current_state = self.sync_arm_queues[arm_name][-1]
|
||||
current_pos = np.array(current_state.position)
|
||||
|
||||
# Clip the action to stay within max relative target
|
||||
max_delta = self.config['max_relative_target']
|
||||
clipped_action = np.clip(arm_action,
|
||||
current_pos - max_delta,
|
||||
current_pos + max_delta)
|
||||
arm_action = clipped_action
|
||||
|
||||
action_sent.append(arm_action)
|
||||
|
||||
# Create and publish JointState message
|
||||
joint_state = JointState()
|
||||
joint_state.header = Header()
|
||||
joint_state.header.stamp = rospy.Time.now()
|
||||
joint_state.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', '']
|
||||
joint_state.position = arm_action.tolist()
|
||||
joint_state.velocity = arm_velocity
|
||||
joint_state.effort = arm_effort
|
||||
|
||||
# Publish to the corresponding topic
|
||||
self.publishers[f"arm_{arm_name}"].publish(joint_state)
|
||||
|
||||
return torch.from_numpy(np.concatenate(action_sent)) if action_sent else torch.tensor([])
|
||||
|
||||
# def _init_action_publishers(self) -> None:
|
||||
# """Initialize ROS publishers for puppet arms"""
|
||||
# self.puppet_arm_publishers = {}
|
||||
# # rospy.init_node("replay_node")
|
||||
# for arm_name, arm_config in self.arms.items():
|
||||
# if not "puppet" in arm_name:
|
||||
# # if not "master" in arm_name:
|
||||
# continue
|
||||
|
||||
# if 'topic_name' not in arm_config:
|
||||
# rospy.logwarn(f"No puppet topic defined for arm {arm_name}")
|
||||
# continue
|
||||
|
||||
# self.puppet_arm_publishers[arm_name] = rospy.Publisher(
|
||||
# arm_config['topic_name'],
|
||||
# JointState,
|
||||
# queue_size=10
|
||||
# )
|
||||
|
||||
# # Wait for publisher to connect
|
||||
# rospy.sleep(0.1)
|
||||
|
||||
# rospy.loginfo("Initialized puppet arm publishers")
|
||||
|
||||
|
||||
|
||||
|
||||
# def get_arguments() -> argparse.Namespace:
|
||||
# """获取运行时参数"""
|
||||
# parser = argparse.ArgumentParser()
|
||||
# parser.add_argument('--fps', type=int, help='Frame rate', default=30)
|
||||
# parser.add_argument('--max_timesteps', type=int, help='Max timesteps', default=500)
|
||||
# parser.add_argument('--episode_idx', type=int, help='Episode index', default=0)
|
||||
# parser.add_argument('--use_depth', action='store_true', help='Use depth images')
|
||||
# parser.add_argument('--use_base', action='store_true', help='Use robot base')
|
||||
# return parser.parse_args()
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# # 示例用法
|
||||
# import json
|
||||
# args = get_arguments()
|
||||
# robot = AgilexRobot(config_file="/home/jgl20/LYT/work/collect_data_lerobot/1.yaml", args=args)
|
||||
# print(json.dumps(robot.features, indent=4))
|
||||
# robot.warmup_record()
|
||||
# count = 0
|
||||
# print_flag = True
|
||||
# rate = rospy.Rate(args.fps)
|
||||
# while (count < args.max_timesteps + 1) and not rospy.is_shutdown():
|
||||
# a, b = robot.teleop_step()
|
||||
|
||||
# if a is None or b is None:
|
||||
# if print_flag:
|
||||
# print("syn fail\n")
|
||||
# print_flag = False
|
||||
# rate.sleep()
|
||||
# continue
|
||||
# else:
|
||||
# print(a)
|
||||
|
||||
|
||||
# timesteps, actions = robot.process()
|
||||
# print(timesteps)
|
||||
print()
|
||||
@@ -408,7 +408,7 @@ def get_arguments():
|
||||
args.fps = 30
|
||||
args.resume = False
|
||||
args.repo_id = "move_the_bottle_from_the_right_to_the_scale_right"
|
||||
args.root = "/home/ubuntu/LYT/aloha_lerobot/data4"
|
||||
args.root = "./data5"
|
||||
args.episode = 0 # replay episode
|
||||
args.num_image_writer_processes = 0
|
||||
args.num_image_writer_threads_per_camera = 4
|
||||
@@ -429,51 +429,26 @@ def get_arguments():
|
||||
|
||||
|
||||
|
||||
# @parser.wrap()
|
||||
# def control_robot(cfg: ControlPipelineConfig):
|
||||
# init_logging()
|
||||
# logging.info(pformat(asdict(cfg)))
|
||||
|
||||
# # robot = make_robot_from_config(cfg.robot)
|
||||
# from agilex_robot import AgilexRobot
|
||||
# robot = AgilexRobot(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg)
|
||||
|
||||
# if isinstance(cfg.control, RecordControlConfig):
|
||||
# print(cfg.control)
|
||||
# record(robot, cfg.control)
|
||||
# elif isinstance(cfg.control, ReplayControlConfig):
|
||||
# replay(robot, cfg.control)
|
||||
|
||||
# # if robot.is_connected:
|
||||
# # # Disconnect manually to avoid a "Core dump" during process
|
||||
# # # termination due to camera threads not properly exiting.
|
||||
# # robot.disconnect()
|
||||
|
||||
|
||||
# @parser.wrap()
|
||||
def control_robot(cfg):
|
||||
|
||||
# robot = make_robot_from_config(cfg.robot)
|
||||
from agilex_robot import AgilexRobot
|
||||
robot = AgilexRobot(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg)
|
||||
from rosrobot_factory import RobotFactory
|
||||
# 使用工厂模式创建机器人实例
|
||||
robot = RobotFactory.create(config_file="/home/ubuntu/LYT/lerobot_aloha/collect_data/agilex.yaml", args=cfg)
|
||||
|
||||
if cfg.control_type == "record":
|
||||
record(robot, cfg)
|
||||
elif cfg.control_type == "replay":
|
||||
replay(robot, cfg)
|
||||
|
||||
# if robot.is_connected:
|
||||
# # Disconnect manually to avoid a "Core dump" during process
|
||||
# # termination due to camera threads not properly exiting.
|
||||
# robot.disconnect()
|
||||
|
||||
if __name__ == "__main__":
|
||||
cfg = get_arguments()
|
||||
control_robot(cfg)
|
||||
# control_robot()
|
||||
# cfg = get_arguments()
|
||||
# from agilex_robot import AgilexRobot
|
||||
# robot = AgilexRobot(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg)
|
||||
# 使用工厂模式创建机器人实例
|
||||
# robot = RobotFactory.create(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg)
|
||||
# print(robot.features.items())
|
||||
# print([key for key, ft in robot.features.items() if ft["dtype"] == "video"])
|
||||
# record(robot, cfg)
|
||||
|
||||
@@ -1,372 +0,0 @@
|
||||
import yaml
|
||||
from collections import deque
|
||||
import rospy
|
||||
from cv_bridge import CvBridge
|
||||
from typing import Dict, Any, Optional, List
|
||||
from sensor_msgs.msg import Image, JointState
|
||||
from nav_msgs.msg import Odometry
|
||||
import argparse
|
||||
|
||||
|
||||
class Robot:
|
||||
def __init__(self, config_file: str, args: Optional[argparse.Namespace] = None):
|
||||
"""
|
||||
机器人基类,处理通用初始化逻辑
|
||||
Args:
|
||||
config_file: YAML配置文件路径
|
||||
args: 运行时参数
|
||||
"""
|
||||
self._load_config(config_file)
|
||||
self._merge_runtime_args(args)
|
||||
self._init_components()
|
||||
self._init_data_structures()
|
||||
self.init_ros()
|
||||
self.init_features()
|
||||
self.warmup()
|
||||
|
||||
def _load_config(self, config_file: str) -> None:
|
||||
"""加载YAML配置文件"""
|
||||
with open(config_file, 'r') as f:
|
||||
self.config = yaml.safe_load(f)
|
||||
|
||||
def _merge_runtime_args(self, args: Optional[argparse.Namespace]) -> None:
|
||||
"""合并运行时参数到配置"""
|
||||
if args is None:
|
||||
return
|
||||
|
||||
runtime_params = {
|
||||
'frame_rate': getattr(args, 'fps', None),
|
||||
'max_timesteps': getattr(args, 'max_timesteps', None),
|
||||
'episode_idx': getattr(args, 'episode_idx', None),
|
||||
'use_depth_image': getattr(args, 'use_depth_image', None),
|
||||
'use_robot_base': getattr(args, 'use_base', None),
|
||||
'video': getattr(args, 'video', False),
|
||||
'control_type': getattr(args, 'control_type', False),
|
||||
}
|
||||
|
||||
for key, value in runtime_params.items():
|
||||
if value is not None:
|
||||
self.config[key] = value
|
||||
|
||||
def _init_components(self) -> None:
|
||||
"""初始化核心组件"""
|
||||
self.bridge = CvBridge()
|
||||
self.subscribers = {}
|
||||
self.publishers = {}
|
||||
self._validate_config()
|
||||
|
||||
def _validate_config(self) -> None:
|
||||
"""验证配置完整性"""
|
||||
required_sections = ['cameras', 'arm']
|
||||
for section in required_sections:
|
||||
if section not in self.config:
|
||||
raise ValueError(f"Missing required config section: {section}")
|
||||
|
||||
def _init_data_structures(self) -> None:
|
||||
"""初始化数据结构模板方法"""
|
||||
# 相机数据
|
||||
self.cameras = self.config.get('cameras', {})
|
||||
self.sync_img_queues = {name: deque(maxlen=2000) for name in self.cameras}
|
||||
|
||||
# 深度数据
|
||||
self.use_depth_image = self.config.get('use_depth_image', False)
|
||||
if self.use_depth_image:
|
||||
self.sync_depth_queues = {
|
||||
name: deque(maxlen=2000)
|
||||
for name, cam in self.cameras.items()
|
||||
if 'depth_topic_name' in cam
|
||||
}
|
||||
|
||||
# 机械臂数据
|
||||
self.arms = self.config.get('arm', {})
|
||||
if self.config.get('control_type', '') != 'record':
|
||||
# 如果不是录制模式,则仅初始化从机械臂数据队列
|
||||
self.sync_arm_queues = {name: deque(maxlen=2000) for name in self.arms if 'puppet' in name}
|
||||
else:
|
||||
self.sync_arm_queues = {name: deque(maxlen=2000) for name in self.arms}
|
||||
|
||||
# 机器人基座数据
|
||||
self.use_robot_base = self.config.get('use_robot_base', False)
|
||||
if self.use_robot_base:
|
||||
self.sync_base_queue = deque(maxlen=2000)
|
||||
|
||||
def init_ros(self) -> None:
|
||||
"""初始化ROS订阅的模板方法"""
|
||||
rospy.init_node(
|
||||
f"{self.config.get('ros_node_name', 'generic_robot_node')}",
|
||||
anonymous=True
|
||||
)
|
||||
|
||||
self._setup_camera_subscribers()
|
||||
self._setup_arm_subscribers_publishers()
|
||||
self._setup_base_subscriber()
|
||||
self._log_ros_status()
|
||||
|
||||
def init_features(self):
|
||||
"""
|
||||
根据YAML配置自动生成features结构
|
||||
"""
|
||||
self.features = {}
|
||||
|
||||
# 初始化相机特征
|
||||
self._init_camera_features()
|
||||
|
||||
# 初始化机械臂特征
|
||||
self._init_state_features()
|
||||
|
||||
self._init_action_features()
|
||||
|
||||
# 初始化基座特征(如果启用)
|
||||
if self.use_robot_base:
|
||||
self._init_base_features()
|
||||
import pprint
|
||||
pprint.pprint(self.features, indent=4)
|
||||
|
||||
|
||||
def _init_camera_features(self):
|
||||
"""处理所有相机特征"""
|
||||
for cam_name, cam_config in self.cameras.items():
|
||||
# 普通图像
|
||||
self.features[f"observation.images.{cam_name}"] = {
|
||||
"dtype": "video" if self.config.get("video", False) else "image",
|
||||
"shape": cam_config.get("rgb_shape", [480, 640, 3]),
|
||||
"names": ["height", "width", "channel"],
|
||||
# "video_info": {
|
||||
# "video.fps": cam_config.get("fps", 30.0),
|
||||
# "video.codec": cam_config.get("codec", "av1"),
|
||||
# "video.pix_fmt": cam_config.get("pix_fmt", "yuv420p"),
|
||||
# "video.is_depth_map": False,
|
||||
# "has_audio": False
|
||||
# }
|
||||
}
|
||||
|
||||
if self.config.get("use_depth_image", False):
|
||||
self.features[f"observation.images.depth_{cam_name}"] = {
|
||||
"dtype": "uint16",
|
||||
"shape": (cam_config.get("width", 480), cam_config.get("height", 640), 1),
|
||||
"names": ["height", "width"],
|
||||
}
|
||||
|
||||
|
||||
def _init_state_features(self):
|
||||
state = self.config.get('state', {})
|
||||
# 状态特征
|
||||
self.features["observation.state"] = {
|
||||
"dtype": "float32",
|
||||
"shape": (len(state.get('motors', "")),),
|
||||
"names": {"motors": state.get('motors', "")}
|
||||
}
|
||||
|
||||
if self.config.get('velocity'):
|
||||
velocity = self.config.get('velocity', "")
|
||||
self.features["observation.velocity"] = {
|
||||
"dtype": "float32",
|
||||
"shape": (len(velocity.get('motors', "")),),
|
||||
"names": {"motors": velocity.get('motors', "")}
|
||||
}
|
||||
|
||||
if self.config.get('effort'):
|
||||
effort = self.config.get('effort', "")
|
||||
self.features["observation.effort"] = {
|
||||
"dtype": "float32",
|
||||
"shape": (len(effort.get('motors', "")),),
|
||||
"names": {"motors": effort.get('motors', "")}
|
||||
}
|
||||
|
||||
|
||||
|
||||
def _init_action_features(self):
|
||||
action = self.config.get('action', {})
|
||||
# 状态特征
|
||||
self.features["action"] = {
|
||||
"dtype": "float32",
|
||||
"shape": (len(action.get('motors', "")),),
|
||||
"names": {"motors": action.get('motors', "")}
|
||||
}
|
||||
|
||||
def _init_base_features(self):
|
||||
"""处理基座特征"""
|
||||
self.features["observation.base_vel"] = {
|
||||
"dtype": "float32",
|
||||
"shape": (2,),
|
||||
"names": ["linear_x", "angular_z"]
|
||||
}
|
||||
|
||||
|
||||
def _setup_camera_subscribers(self) -> None:
|
||||
"""设置相机订阅者"""
|
||||
for cam_name, cam_config in self.cameras.items():
|
||||
if 'img_topic_name' in cam_config:
|
||||
self.subscribers[f"camera_{cam_name}"] = rospy.Subscriber(
|
||||
cam_config['img_topic_name'],
|
||||
Image,
|
||||
self._make_camera_callback(cam_name, is_depth=False),
|
||||
queue_size=1000,
|
||||
tcp_nodelay=True
|
||||
)
|
||||
|
||||
if self.use_depth_image and 'depth_topic_name' in cam_config:
|
||||
self.subscribers[f"depth_{cam_name}"] = rospy.Subscriber(
|
||||
cam_config['depth_topic_name'],
|
||||
Image,
|
||||
self._make_camera_callback(cam_name, is_depth=True),
|
||||
queue_size=1000,
|
||||
tcp_nodelay=True
|
||||
)
|
||||
|
||||
def _setup_arm_subscribers_publishers(self) -> None:
|
||||
"""设置机械臂订阅者"""
|
||||
# 当为record模式时,主从机械臂都需要订阅
|
||||
# 否则只订阅从机械臂,但向主机械臂发布
|
||||
if self.config.get('control_type', '') == 'record':
|
||||
for arm_name, arm_config in self.arms.items():
|
||||
if 'topic_name' in arm_config:
|
||||
self.subscribers[f"arm_{arm_name}"] = rospy.Subscriber(
|
||||
arm_config['topic_name'],
|
||||
JointState,
|
||||
self._make_arm_callback(arm_name),
|
||||
queue_size=1000,
|
||||
tcp_nodelay=True
|
||||
)
|
||||
else:
|
||||
for arm_name, arm_config in self.arms.items():
|
||||
if 'puppet' in arm_name:
|
||||
self.subscribers[f"arm_{arm_name}"] = rospy.Subscriber(
|
||||
arm_config['topic_name'],
|
||||
JointState,
|
||||
self._make_arm_callback(arm_name),
|
||||
queue_size=1000,
|
||||
tcp_nodelay=True
|
||||
)
|
||||
if 'master' in arm_name:
|
||||
self.publishers[f"arm_{arm_name}"] = rospy.Publisher(
|
||||
arm_config['topic_name'],
|
||||
JointState,
|
||||
queue_size=10
|
||||
)
|
||||
|
||||
def _setup_base_subscriber(self) -> None:
|
||||
"""设置基座订阅者"""
|
||||
if self.use_robot_base and 'robot_base' in self.config:
|
||||
self.subscribers['base'] = rospy.Subscriber(
|
||||
self.config['robot_base']['topic_name'],
|
||||
Odometry,
|
||||
self.robot_base_callback,
|
||||
queue_size=1000,
|
||||
tcp_nodelay=True
|
||||
)
|
||||
|
||||
def _log_ros_status(self) -> None:
|
||||
"""记录ROS状态"""
|
||||
rospy.loginfo("\n=== ROS订阅状态 ===")
|
||||
rospy.loginfo(f"已初始化节点: {rospy.get_name()}")
|
||||
rospy.loginfo("活跃的订阅者:")
|
||||
for topic, sub in self.subscribers.items():
|
||||
rospy.loginfo(f" - {topic}: {'活跃' if sub.impl else '未连接'}")
|
||||
rospy.loginfo("=================")
|
||||
|
||||
def _make_camera_callback(self, cam_name: str, is_depth: bool = False):
|
||||
"""生成相机回调函数工厂方法"""
|
||||
def callback(msg):
|
||||
try:
|
||||
target_queue = (
|
||||
self.sync_depth_queues[cam_name]
|
||||
if is_depth
|
||||
else self.sync_img_queues[cam_name]
|
||||
)
|
||||
if len(target_queue) >= 2000:
|
||||
target_queue.popleft()
|
||||
target_queue.append(msg)
|
||||
except Exception as e:
|
||||
rospy.logerr(f"Camera {cam_name} callback error: {str(e)}")
|
||||
return callback
|
||||
|
||||
def _make_arm_callback(self, arm_name: str):
|
||||
"""生成机械臂回调函数工厂方法"""
|
||||
def callback(msg):
|
||||
try:
|
||||
if len(self.sync_arm_queues[arm_name]) >= 2000:
|
||||
self.sync_arm_queues[arm_name].popleft()
|
||||
self.sync_arm_queues[arm_name].append(msg)
|
||||
except Exception as e:
|
||||
rospy.logerr(f"Arm {arm_name} callback error: {str(e)}")
|
||||
return callback
|
||||
|
||||
def robot_base_callback(self, msg):
|
||||
"""基座回调默认实现"""
|
||||
if len(self.sync_base_queue) >= 2000:
|
||||
self.sync_base_queue.popleft()
|
||||
self.sync_base_queue.append(msg)
|
||||
|
||||
def warmup(self, timeout: float = 10.0) -> bool:
|
||||
"""Wait until all data queues have at least 20 messages.
|
||||
|
||||
Args:
|
||||
timeout: Maximum time to wait in seconds before giving up
|
||||
|
||||
Returns:
|
||||
bool: True if warmup succeeded, False if timed out
|
||||
"""
|
||||
start_time = rospy.Time.now().to_sec()
|
||||
rate = rospy.Rate(10) # Check at 10Hz
|
||||
|
||||
rospy.loginfo("Starting warmup process...")
|
||||
|
||||
while not rospy.is_shutdown():
|
||||
# Check if timeout has been reached
|
||||
current_time = rospy.Time.now().to_sec()
|
||||
if current_time - start_time > timeout:
|
||||
rospy.logwarn("Warmup timed out before all queues were filled")
|
||||
return False
|
||||
|
||||
# Check all required queues
|
||||
all_ready = True
|
||||
|
||||
# Check camera image queues
|
||||
for cam_name in self.cameras:
|
||||
if len(self.sync_img_queues[cam_name]) < 50:
|
||||
rospy.loginfo(f"Waiting for camera {cam_name} (current: {len(self.sync_img_queues[cam_name])}/50)")
|
||||
all_ready = False
|
||||
break
|
||||
|
||||
# Check depth queues if enabled
|
||||
if self.use_depth_image:
|
||||
for cam_name in self.sync_depth_queues:
|
||||
if len(self.sync_depth_queues[cam_name]) < 50:
|
||||
rospy.loginfo(f"Waiting for depth camera {cam_name} (current: {len(self.sync_depth_queues[cam_name])}/50)")
|
||||
all_ready = False
|
||||
break
|
||||
|
||||
# # Check arm queues
|
||||
# for arm_name in self.arms:
|
||||
# if len(self.sync_arm_queues[arm_name]) < 20:
|
||||
# rospy.loginfo(f"Waiting for arm {arm_name} (current: {len(self.sync_arm_queues[arm_name])}/20)")
|
||||
# all_ready = False
|
||||
# break
|
||||
|
||||
# Check base queue if enabled
|
||||
if self.use_robot_base:
|
||||
if len(self.sync_base_queue) < 20:
|
||||
rospy.loginfo(f"Waiting for base (current: {len(self.sync_base_queue)}/20)")
|
||||
all_ready = False
|
||||
|
||||
# If all queues are ready, return success
|
||||
if all_ready:
|
||||
rospy.loginfo("Warmup completed successfully")
|
||||
return True
|
||||
|
||||
rate.sleep()
|
||||
|
||||
return False
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def get_frame(self) -> Optional[Dict[str, Any]]:
|
||||
"""获取同步帧数据的模板方法"""
|
||||
raise NotImplementedError("Subclasses must implement get_frame()")
|
||||
|
||||
def process(self) -> tuple:
|
||||
"""主处理循环的模板方法"""
|
||||
raise NotImplementedError("Subclasses must implement process()")
|
||||
@@ -1,26 +0,0 @@
|
||||
import yaml
|
||||
import argparse
|
||||
from typing import Dict, List, Any, Optional
|
||||
from rosrobot import Robot
|
||||
from agilex_robot import AgilexRobot
|
||||
|
||||
|
||||
class RobotFactory:
|
||||
@staticmethod
|
||||
def create(config_file: str, args: Optional[argparse.Namespace] = None) -> Robot:
|
||||
"""
|
||||
根据配置文件自动创建合适的机器人实例
|
||||
Args:
|
||||
config_file: 配置文件路径
|
||||
args: 运行时参数
|
||||
"""
|
||||
with open(config_file, 'r') as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
robot_type = config.get('robot_type', 'agilex')
|
||||
|
||||
if robot_type == 'agilex':
|
||||
return AgilexRobot(config_file, args)
|
||||
# 可扩展其他机器人类型
|
||||
else:
|
||||
raise ValueError(f"Unsupported robot type: {robot_type}")
|
||||
Reference in New Issue
Block a user