lerobot_aloha/collect_data/agilex_robot.py

import yaml
import cv2
import numpy as np
import collections
import dm_env
import argparse
from typing import Dict, List, Any, Optional
from collections import deque
import rospy
from cv_bridge import CvBridge
from std_msgs.msg import Header
from sensor_msgs.msg import Image, JointState
from nav_msgs.msg import Odometry
from rosrobot import Robot
import torch
import time


class AgilexRobot(Robot):
def get_frame(self) -> Optional[Dict[str, Any]]:
"""
获取同步帧数据
返回: 包含同步数据的字典或None如果同步失败
"""
        # Check basic data availability
# print(self.sync_img_queues.values())
if any(len(q) == 0 for q in self.sync_img_queues.values()):
print("camera has not get image data")
return None
if self.use_depth_image and any(len(q) == 0 for q in self.sync_depth_queues.values()):
return None
# print(self.sync_arm_queues.values())
# if any(len(q) == 0 for q in self.sync_arm_queues.values()):
# print("2")
# if len(self.sync_arm_queues['master_left']) == 0:
# print("can not get data from master topic")
# if len(self.sync_arm_queues['puppet_left']) == 0 or len(self.sync_arm_queues['puppet_right']) == 0:
# print("can not get data from puppet topic")
# return None
if len(self.sync_arm_queues['puppet_left']) == 0 or len(self.sync_arm_queues['puppet_right']) == 0:
print("can not get data from puppet topic")
return None
        # Compute the minimum of the latest timestamps across all queues
timestamps = [
q[-1].header.stamp.to_sec()
for q in list(self.sync_img_queues.values()) +
list(self.sync_arm_queues.values())
]
if self.use_depth_image:
timestamps.extend(q[-1].header.stamp.to_sec() for q in self.sync_depth_queues.values())
if self.use_robot_base and len(self.sync_base_queue) > 0:
timestamps.append(self.sync_base_queue[-1].header.stamp.to_sec())
min_time = min(timestamps)
        # Check data synchronization
for queue in list(self.sync_img_queues.values()) + list(self.sync_arm_queues.values()):
if queue[-1].header.stamp.to_sec() < min_time:
return None
if self.use_depth_image:
for queue in self.sync_depth_queues.values():
if queue[-1].header.stamp.to_sec() < min_time:
return None
if self.use_robot_base and len(self.sync_base_queue) > 0:
if self.sync_base_queue[-1].header.stamp.to_sec() < min_time:
return None
        # Extract the synchronized data
frame_data = {
'images': {},
'arms': {},
'timestamp': min_time
}
        # Image data
for cam_name, queue in self.sync_img_queues.items():
while queue[0].header.stamp.to_sec() < min_time:
queue.popleft()
frame_data['images'][cam_name] = self.bridge.imgmsg_to_cv2(queue.popleft(), 'passthrough')
        # Depth data
if self.use_depth_image:
frame_data['depths'] = {}
for cam_name, queue in self.sync_depth_queues.items():
while queue[0].header.stamp.to_sec() < min_time:
queue.popleft()
depth_img = self.bridge.imgmsg_to_cv2(queue.popleft(), 'passthrough')
                # Keep the original border padding
frame_data['depths'][cam_name] = cv2.copyMakeBorder(
depth_img, 40, 40, 0, 0, cv2.BORDER_CONSTANT, value=0
)
        # Arm data
for arm_name, queue in self.sync_arm_queues.items():
while queue[0].header.stamp.to_sec() < min_time:
queue.popleft()
frame_data['arms'][arm_name] = queue.popleft()
        # Base data
if self.use_robot_base and len(self.sync_base_queue) > 0:
while self.sync_base_queue[0].header.stamp.to_sec() < min_time:
self.sync_base_queue.popleft()
frame_data['base'] = self.sync_base_queue.popleft()
return frame_data
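    # Sketch of the dict returned by get_frame() above (the camera and arm names are
    # placeholders; the real keys come from the YAML config):
    #
    #   {
    #       'images':    {'<cam_name>': np.ndarray, ...},      # one decoded image per camera
    #       'depths':    {'<cam_name>': np.ndarray, ...},      # only when use_depth_image is set
    #       'arms':      {'puppet_left': JointState, ...},     # raw sensor_msgs/JointState messages
    #       'base':      Odometry,                             # only when use_robot_base is set
    #       'timestamp': float,                                # min of the latest stamps, in seconds
    #   }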
    def teleop_step(self) -> tuple[Optional[dict[str, torch.Tensor]], Optional[dict[str, torch.Tensor]]]:
        """
        Get synchronized frame data; the output follows the teleop_step format.
        Returns: (obs_dict, action_dict), or (None, None) if synchronization fails.
        """
        # Check basic data availability
if any(len(q) == 0 for q in self.sync_img_queues.values()):
return None, None
if self.use_depth_image and any(len(q) == 0 for q in self.sync_depth_queues.values()):
return None, None
if any(len(q) == 0 for q in self.sync_arm_queues.values()):
return None, None
        # Compute the minimum of the latest timestamps across all queues
timestamps = [
q[-1].header.stamp.to_sec()
for q in list(self.sync_img_queues.values()) +
list(self.sync_arm_queues.values())
]
if self.use_depth_image:
timestamps.extend(q[-1].header.stamp.to_sec() for q in self.sync_depth_queues.values())
if self.use_robot_base and len(self.sync_base_queue) > 0:
timestamps.append(self.sync_base_queue[-1].header.stamp.to_sec())
min_time = min(timestamps)
        # Check data synchronization
for queue in list(self.sync_img_queues.values()) + list(self.sync_arm_queues.values()):
if queue[-1].header.stamp.to_sec() < min_time:
return None, None
if self.use_depth_image:
for queue in self.sync_depth_queues.values():
if queue[-1].header.stamp.to_sec() < min_time:
return None, None
if self.use_robot_base and len(self.sync_base_queue) > 0:
if self.sync_base_queue[-1].header.stamp.to_sec() < min_time:
return None, None
        # Initialize the output dictionaries
obs_dict = {}
action_dict = {}
        # Process image data
for cam_name, queue in self.sync_img_queues.items():
while queue[0].header.stamp.to_sec() < min_time:
queue.popleft()
img = self.bridge.imgmsg_to_cv2(queue.popleft(), 'passthrough')
obs_dict[f"observation.images.{cam_name}"] = torch.from_numpy(img)
        # Process depth data
if self.use_depth_image:
for cam_name, queue in self.sync_depth_queues.items():
while queue[0].header.stamp.to_sec() < min_time:
queue.popleft()
depth_img = self.bridge.imgmsg_to_cv2(queue.popleft(), 'passthrough')
depth_img = cv2.copyMakeBorder(
depth_img, 40, 40, 0, 0, cv2.BORDER_CONSTANT, value=0
)
obs_dict[f"observation.images.depth_{cam_name}"] = torch.from_numpy(depth_img).unsqueeze(-1)
        # Process arm observation data
arm_states = []
arm_velocity = []
arm_effort = []
actions = []
for arm_name, queue in self.sync_arm_queues.items():
while queue[0].header.stamp.to_sec() < min_time:
queue.popleft()
arm_data = queue.popleft()
# np.array(arm_data.position),
# np.array(arm_data.velocity),
# np.array(arm_data.effort)
            # Puppet (follower) arms provide observations
if arm_name.startswith('puppet'):
arm_states.append(np.array(arm_data.position, dtype=np.float32))
arm_velocity.append(np.array(arm_data.velocity, dtype=np.float32))
arm_effort.append(np.array(arm_data.effort, dtype=np.float32))
            # Master (leader) arms provide actions
if arm_name.startswith('master'):
# action_dict[f"action.{arm_name}"] = torch.from_numpy(np.array(arm_data.position))
actions.append(np.array(arm_data.position, dtype=np.float32))
if arm_states:
obs_dict["observation.state"] = torch.tensor(np.concatenate(arm_states).reshape(-1)) # 先转Python列表
if arm_velocity:
obs_dict["observation.velocity"] = torch.tensor(np.concatenate(arm_velocity).reshape(-1))
if arm_effort:
obs_dict["observation.effort"] = torch.tensor(np.concatenate(arm_effort).reshape(-1))
if actions:
action_dict["action"] = torch.tensor(np.concatenate(actions).reshape(-1))
# action_dict["action"] = np.concatenate(actions).squeeze()
        # Process base data
if self.use_robot_base and len(self.sync_base_queue) > 0:
while self.sync_base_queue[0].header.stamp.to_sec() < min_time:
self.sync_base_queue.popleft()
base_data = self.sync_base_queue.popleft()
obs_dict["observation.base_vel"] = torch.tensor([
base_data.twist.twist.linear.x,
base_data.twist.twist.angular.z
], dtype=torch.float32)
        # Add timestamp
# obs_dict["observation.timestamp"] = torch.tensor(min_time, dtype=torch.float64)
return obs_dict, action_dict
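    # Sketch of the (obs_dict, action_dict) pair returned by teleop_step() above
    # (camera names are placeholders; the tensor sizes depend on the configured arms):
    #
    #   obs_dict = {
    #       "observation.images.<cam_name>":       torch.Tensor,  # HxWxC image per camera
    #       "observation.images.depth_<cam_name>": torch.Tensor,  # HxWx1, only with use_depth_image
    #       "observation.state":    torch.Tensor,  # concatenated puppet joint positions
    #       "observation.velocity": torch.Tensor,  # concatenated puppet joint velocities
    #       "observation.effort":   torch.Tensor,  # concatenated puppet joint efforts
    #       "observation.base_vel": torch.Tensor,  # [linear.x, angular.z], only with use_robot_base
    #   }
    #   action_dict = {"action": torch.Tensor}     # concatenated master joint positions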
def capture_observation(self):
"""Capture observation data from ROS topics without batch dimension.
Returns:
dict: Observation dictionary containing state and images.
Raises:
RobotDeviceNotConnectedError: If robot is not connected.
"""
# Initialize observation dictionary
obs_dict = {}
# Get synchronized frame data
frame_data = self.get_frame()
if frame_data is None:
# raise RuntimeError("Failed to capture synchronized observation data")
return None
# Process arm state data (from puppet arms)
arm_states = []
arm_velocity = []
arm_effort = []
for arm_name, joint_state in frame_data['arms'].items():
if arm_name.startswith('puppet'):
# Record timing for performance monitoring
before_read_t = time.perf_counter()
# Get position data and convert to tensor
pos = torch.from_numpy(np.array(joint_state.position, dtype=np.float32))
arm_states.append(pos)
velocity = torch.from_numpy(np.array(joint_state.velocity, dtype=np.float32))
arm_velocity.append(velocity)
effort = torch.from_numpy(np.array(joint_state.effort, dtype=np.float32))
arm_effort.append(effort)
# Log timing information
# self.logs[f"read_arm_{arm_name}_pos_dt_s"] = time.perf_counter() - before_read_t
print(f"read_arm_{arm_name}_pos_dt_s is", time.perf_counter() - before_read_t)
# Combine all arm states into single tensor
if arm_states:
obs_dict["observation.state"] = torch.cat(arm_states)
if arm_velocity:
obs_dict["observation.velocity"] = torch.cat(arm_velocity)
if arm_effort:
obs_dict["observation.effort"] = torch.cat(arm_effort)
# Process image data
for cam_name, img in frame_data['images'].items():
# Record timing for performance monitoring
before_camread_t = time.perf_counter()
# Convert image to tensor
img_tensor = torch.from_numpy(img)
obs_dict[f"observation.images.{cam_name}"] = img_tensor
# Log timing information
# self.logs[f"read_camera_{cam_name}_dt_s"] = time.perf_counter() - before_camread_t
print(f"read_camera_{cam_name}_dt_s is", time.perf_counter() - before_camread_t)
# Process depth data if enabled
if self.use_depth_image and 'depths' in frame_data:
for cam_name, depth_img in frame_data['depths'].items():
before_depthread_t = time.perf_counter()
# Convert depth image to tensor and add channel dimension
depth_tensor = torch.from_numpy(depth_img).unsqueeze(-1)
obs_dict[f"observation.images.depth_{cam_name}"] = depth_tensor
# self.logs[f"read_depth_{cam_name}_dt_s"] = time.perf_counter() - before_depthread_t
print(f"read_depth_{cam_name}_dt_s is", time.perf_counter() - before_depthread_t)
# Process base velocity if enabled
if self.use_robot_base and 'base' in frame_data:
base_data = frame_data['base']
obs_dict["observation.base_vel"] = torch.tensor([
base_data.twist.twist.linear.x,
base_data.twist.twist.angular.z
], dtype=torch.float32)
return obs_dict
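    # Minimal polling sketch for capture_observation(); it returns None until every
    # subscribed queue holds data at the common timestamp, so callers usually retry at
    # the target rate. The node name and 30 Hz rate are assumptions, and `robot` is
    # assumed to be an already constructed AgilexRobot:
    #
    #   rospy.init_node("capture_example")
    #   rate = rospy.Rate(30)
    #   obs = None
    #   while obs is None and not rospy.is_shutdown():
    #       obs = robot.capture_observation()
    #       rate.sleep()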
def send_action(self, action: torch.Tensor) -> torch.Tensor:
"""
Send joint position commands to the puppet arms via ROS.
Args:
action: Tensor containing concatenated goal positions for all puppet arms
Shape should match the action space defined in features["action"]
Returns:
The actual action that was sent (may be clipped if safety checks are implemented)
"""
# if not hasattr(self, 'puppet_arm_publishers'):
# # Initialize publishers on first call
# self._init_action_publishers()
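        # Hard-coded velocity/effort values reused for every published joint command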
last_velocity = [-0.010990142822265625, -0.010990142822265625, -0.03296661376953125, 0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.03296661376953125]
last_effort = [-0.021978378295898438, 0.2417583465576172, 4.320878982543945, 3.6527481079101562, -0.013187408447265625, -0.013187408447265625, 0.0, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.03296661376953125]
# Convert tensor to numpy array if needed
if isinstance(action, torch.Tensor):
action = action.detach().cpu().numpy()
# Split action into individual arm commands based on config
from_idx = 0
to_idx = 0
action_sent = []
for arm_name, arm_config in self.arms.items():
            # Only master arms are published as actions; skip the rest
            if "master" not in arm_name:
continue
# Get number of joints for this arm
num_joints = len(arm_config.get('motors', []))
to_idx += num_joints
# Extract this arm's portion of the action
arm_action = action[from_idx:to_idx]
arm_velocity = last_velocity[from_idx:to_idx]
arm_effort = last_effort[from_idx:to_idx]
from_idx = to_idx
# Apply safety checks if configured
if 'max_relative_target' in self.config:
# Get current position from the queue
if len(self.sync_arm_queues[arm_name]) > 0:
current_state = self.sync_arm_queues[arm_name][-1]
current_pos = np.array(current_state.position)
# Clip the action to stay within max relative target
max_delta = self.config['max_relative_target']
clipped_action = np.clip(arm_action,
current_pos - max_delta,
current_pos + max_delta)
arm_action = clipped_action
action_sent.append(arm_action)
# Create and publish JointState message
joint_state = JointState()
joint_state.header = Header()
joint_state.header.stamp = rospy.Time.now()
joint_state.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', '']
joint_state.position = arm_action.tolist()
joint_state.velocity = arm_velocity
joint_state.effort = arm_effort
# Publish to the corresponding topic
self.publishers[f"arm_{arm_name}"].publish(joint_state)
return torch.from_numpy(np.concatenate(action_sent)) if action_sent else torch.tensor([])
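    # Replay sketch for send_action() above: it expects a single concatenated position
    # vector covering every master arm, in the order the master arms appear in self.arms
    # (the 14-dim shape assumes two 7-joint arms and is only an illustration):
    #
    #   action = torch.zeros(14)          # e.g. one frame from a recorded episode
    #   sent = robot.send_action(action)  # publishes one JointState per master arm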
# def _init_action_publishers(self) -> None:
# """Initialize ROS publishers for puppet arms"""
# self.puppet_arm_publishers = {}
# # rospy.init_node("replay_node")
# for arm_name, arm_config in self.arms.items():
# if not "puppet" in arm_name:
# # if not "master" in arm_name:
# continue
# if 'topic_name' not in arm_config:
# rospy.logwarn(f"No puppet topic defined for arm {arm_name}")
# continue
# self.puppet_arm_publishers[arm_name] = rospy.Publisher(
# arm_config['topic_name'],
# JointState,
# queue_size=10
# )
# # Wait for publisher to connect
# rospy.sleep(0.1)
# rospy.loginfo("Initialized puppet arm publishers")
# def get_arguments() -> argparse.Namespace:
# """获取运行时参数"""
# parser = argparse.ArgumentParser()
# parser.add_argument('--fps', type=int, help='Frame rate', default=30)
# parser.add_argument('--max_timesteps', type=int, help='Max timesteps', default=500)
# parser.add_argument('--episode_idx', type=int, help='Episode index', default=0)
# parser.add_argument('--use_depth', action='store_true', help='Use depth images')
# parser.add_argument('--use_base', action='store_true', help='Use robot base')
# return parser.parse_args()
# if __name__ == "__main__":
# # Example usage
# import json
# args = get_arguments()
# robot = AgilexRobot(config_file="/home/jgl20/LYT/work/collect_data_lerobot/1.yaml", args=args)
# print(json.dumps(robot.features, indent=4))
# robot.warmup_record()
# count = 0
# print_flag = True
# rate = rospy.Rate(args.fps)
# while (count < args.max_timesteps + 1) and not rospy.is_shutdown():
# a, b = robot.teleop_step()
# if a is None or b is None:
# if print_flag:
# print("syn fail\n")
# print_flag = False
# rate.sleep()
# continue
# else:
# print(a)
# timesteps, actions = robot.process()
# print(timesteps)
print()